diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index 2596a18c..d21d0684 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -11,8 +11,8 @@ jobs: test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-go@v6 with: go-version-file: backend/go.mod check-latest: false @@ -30,8 +30,8 @@ jobs: golangci-lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-go@v6 with: go-version-file: backend/go.mod check-latest: false @@ -43,5 +43,5 @@ jobs: uses: golangci/golangci-lint-action@v9 with: version: v2.7 - args: --timeout=5m - working-directory: backend + args: --timeout=30m + working-directory: backend \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 50bb73e0..a1c6aa23 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -31,7 +31,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Update VERSION file run: | @@ -45,7 +45,7 @@ jobs: echo "Updated VERSION file to: $VERSION" - name: Upload VERSION artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: version-file path: backend/cmd/server/VERSION @@ -55,7 +55,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup pnpm uses: pnpm/action-setup@v4 @@ -63,7 +63,7 @@ jobs: version: 9 - name: Setup Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version: '20' cache: 'pnpm' @@ -78,7 +78,7 @@ jobs: working-directory: frontend - name: Upload frontend artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: frontend-dist path: backend/internal/web/dist/ @@ -89,25 +89,25 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.inputs.tag || github.ref }} - name: Download VERSION artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v8 with: name: version-file path: backend/cmd/server/ - name: Download frontend artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v8 with: name: frontend-dist path: backend/internal/web/dist/ - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version-file: backend/go.mod check-latest: false @@ -173,7 +173,7 @@ jobs: run: echo "owner=$(echo '${{ github.repository_owner }}' | tr '[:upper:]' '[:lower:]')" >> $GITHUB_OUTPUT - name: Run GoReleaser - uses: goreleaser/goreleaser-action@v6 + uses: goreleaser/goreleaser-action@v7 with: version: '~> v2' args: release --clean --skip=validate ${{ env.SIMPLE_RELEASE == 'true' && '--config=.goreleaser.simple.yaml' || '' }} @@ -188,7 +188,7 @@ jobs: # Update DockerHub description - name: Update DockerHub description if: ${{ env.SIMPLE_RELEASE != 'true' && env.DOCKERHUB_USERNAME != '' }} - uses: peter-evans/dockerhub-description@v4 + uses: peter-evans/dockerhub-description@v5 env: DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} with: diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml index 05dd1d1a..db922509 100644 --- a/.github/workflows/security-scan.yml +++ b/.github/workflows/security-scan.yml @@ -12,10 +12,11 @@ permissions: jobs: backend-security: runs-on: ubuntu-latest + timeout-minutes: 15 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version-file: backend/go.mod check-latest: false @@ -28,22 +29,17 @@ jobs: run: | go install golang.org/x/vuln/cmd/govulncheck@latest govulncheck ./... - - name: Run gosec - working-directory: backend - run: | - go install github.com/securego/gosec/v2/cmd/gosec@latest - gosec -severity high -confidence high ./... frontend-security: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up pnpm uses: pnpm/action-setup@v4 with: version: 9 - name: Set up Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version: '20' cache: 'pnpm' diff --git a/.gitignore b/.gitignore index 48172982..297c1d6f 100644 --- a/.gitignore +++ b/.gitignore @@ -116,17 +116,20 @@ backend/.installed # =================== tests CLAUDE.md -AGENTS.md .claude scripts .code-review-state -openspec/ -docs/ +#openspec/ code-reviews/ -AGENTS.md +#AGENTS.md backend/cmd/server/server deploy/docker-compose.override.yml .gocache/ vite.config.js docs/* -.serena/ \ No newline at end of file +.serena/ +.codex/ +frontend/coverage/ +aicodex +output/ + diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..bb5bb465 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,105 @@ +# Repository Guidelines + +## Project Structure & Module Organization +- `backend/`: Go service. `cmd/server` is the entrypoint, `internal/` contains handlers/services/repositories/server wiring, `ent/` holds Ent schemas and generated ORM code, `migrations/` stores DB migrations, and `internal/web/dist/` is the embedded frontend build output. +- `frontend/`: Vue 3 + TypeScript app. Main folders are `src/api`, `src/components`, `src/views`, `src/stores`, `src/composables`, `src/utils`, and test files in `src/**/__tests__`. +- `deploy/`: Docker and deployment assets (`docker-compose*.yml`, `.env.example`, `config.example.yaml`). +- `openspec/`: Spec-driven change docs (`changes//{proposal,design,tasks}.md`). +- `tools/`: Utility scripts (security/perf checks). + +## Build, Test, and Development Commands +```bash +make build # Build backend + frontend +make test # Backend tests + frontend lint/typecheck +cd backend && make build # Build backend binary +cd backend && make test-unit # Go unit tests +cd backend && make test-integration # Go integration tests +cd backend && make test # go test ./... + golangci-lint +cd frontend && pnpm install --frozen-lockfile +cd frontend && pnpm dev # Vite dev server +cd frontend && pnpm build # Type-check + production build +cd frontend && pnpm test:run # Vitest run +cd frontend && pnpm test:coverage # Vitest + coverage report +python3 tools/secret_scan.py # Secret scan +``` + +## Coding Style & Naming Conventions +- Go: format with `gofmt`; lint with `golangci-lint` (`backend/.golangci.yml`). +- Respect layering: `internal/service` and `internal/handler` must not import `internal/repository`, `gorm`, or `redis` directly (enforced by depguard). +- Frontend: Vue SFC + TypeScript, 2-space indentation, ESLint rules from `frontend/.eslintrc.cjs`. +- Naming: components use `PascalCase.vue`, composables use `useXxx.ts`, Go tests use `*_test.go`, frontend tests use `*.spec.ts`. + +## Go & Frontend Development Standards +- Control branch complexity: `if` nesting must not exceed 3 levels. Refactor with guard clauses, early returns, helper functions, or strategy maps when deeper logic appears. +- JSON hot-path rule: for read-only/partial-field extraction, prefer `gjson` over full `encoding/json` struct unmarshal to reduce allocations and improve latency. +- Exception rule: if full schema validation or typed writes are required, `encoding/json` is allowed, but PR must explain why `gjson` is not suitable. + +### Go Performance Rules +- Optimization workflow rule: benchmark/profile first, then optimize. Use `go test -bench`, `go tool pprof`, and runtime diagnostics before changing hot-path code. +- For hot functions, run escape analysis (`go build -gcflags=all='-m -m'`) and prioritize stack allocation where reasonable. +- Every external I/O path must use `context.Context` with explicit timeout/cancel. +- When creating derived contexts (`WithTimeout` / `WithDeadline`), always `defer cancel()` to release resources. +- Preallocate slices/maps when size can be estimated (`make([]T, 0, n)`, `make(map[K]V, n)`). +- Avoid unnecessary allocations in loops; reuse buffers and prefer `strings.Builder`/`bytes.Buffer`. +- Prohibit N+1 query patterns; batch DB/Redis operations and verify indexes for new query paths. +- For hot-path changes, include benchmark or latency comparison evidence (e.g., `go test -bench` before/after). +- Keep goroutine growth bounded (worker pool/semaphore), and avoid unbounded fan-out. +- Lock minimization rule: if a lock can be avoided, do not use a lock. Prefer ownership transfer (channel), sharding, immutable snapshots, copy-on-write, or atomic operations to reduce contention. +- When locks are unavoidable, keep critical sections minimal, avoid nested locks, and document why lock-free alternatives are not feasible. +- Follow `sync` guidance: prefer channels for higher-level synchronization; use low-level mutex primitives only where necessary. +- Avoid reflection and `interface{}`-heavy conversions in hot paths; use typed structs/functions. +- Use `sync.Pool` only when benchmark proves allocation reduction; remove if no measurable gain. +- Avoid repeated `time.Now()`/`fmt.Sprintf` in tight loops; hoist or cache when possible. +- For stable high-traffic binaries, maintain representative `default.pgo` profiles and keep `go build -pgo=auto` enabled. + +### Data Access & Cache Rules +- Every new/changed SQL query must be checked with `EXPLAIN` (or `EXPLAIN ANALYZE` in staging) and include index rationale in PR. +- Default to keyset pagination for large tables; avoid deep `OFFSET` scans on hot endpoints. +- Query only required columns; prohibit broad `SELECT *` in latency-sensitive paths. +- Keep transactions short; never perform external RPC/network calls inside DB transactions. +- Connection pool must be explicitly tuned and observed via `DB.Stats` (`SetMaxOpenConns`, `SetMaxIdleConns`, `SetConnMaxIdleTime`, `SetConnMaxLifetime`). +- Avoid overly small `MaxOpenConns` that can turn DB access into lock/semaphore bottlenecks. +- Cache keys must be versioned (e.g., `user_usage:v2:{id}`) and TTL should include jitter to avoid thundering herd. +- Use request coalescing (`singleflight` or equivalent) for high-concurrency cache miss paths. + +### Frontend Performance Rules +- Route-level and heavy-module code splitting is required; lazy-load non-critical views/components. +- API requests must support cancellation and deduplication; use debounce/throttle for search-like inputs. +- Minimize unnecessary reactivity: avoid deep watch chains when computed/cache can solve it. +- Prefer stable props and selective rendering controls (`v-once`, `v-memo`) for expensive subtrees when data is static or keyed. +- Large data rendering must use pagination or virtualization (especially tables/lists >200 rows). +- Move expensive CPU work off the main thread (Web Worker) or chunk tasks to avoid UI blocking. +- Keep bundle growth controlled; avoid adding heavy dependencies without clear ROI and alternatives review. +- Avoid expensive inline computations in templates; move to cached `computed` selectors. +- Keep state normalized; avoid duplicated derived state across multiple stores/components. +- Load charts/editors/export libraries on demand only (`dynamic import`) instead of app-entry import. +- Core Web Vitals targets (p75): `LCP <= 2.5s`, `INP <= 200ms`, `CLS <= 0.1`. +- Main-thread task budget: keep individual tasks below ~50ms; split long tasks and yield between chunks. +- Enforce frontend budgets in CI (Lighthouse CI with `budget.json`) for critical routes. + +### Performance Budget & PR Evidence +- Performance budget is mandatory for hot-path PRs: backend p95/p99 latency and CPU/memory must not regress by more than 5% versus baseline. +- Frontend budget: new route-level JS should not increase by more than 30KB gzip without explicit approval. +- For any gateway/protocol hot path, attach a reproducible benchmark command and results (input size, concurrency, before/after table). +- Profiling evidence is required for major optimizations (`pprof`, flamegraph, browser performance trace, or bundle analyzer output). + +### Quality Gate +- Any changed code must include new or updated unit tests. +- Coverage must stay above 85% (global frontend threshold and no regressions for touched backend modules). +- If any rule is intentionally violated, document reason, risk, and mitigation in the PR description. + +## Testing Guidelines +- Backend suites: `go test -tags=unit ./...`, `go test -tags=integration ./...`, and e2e where relevant. +- Frontend uses Vitest (`jsdom`); keep tests near modules (`__tests__`) or as `*.spec.ts`. +- Enforce unit-test and coverage rules defined in `Quality Gate`. +- Before opening a PR, run `make test` plus targeted tests for touched areas. + +## Commit & Pull Request Guidelines +- Follow Conventional Commits: `feat(scope): ...`, `fix(scope): ...`, `chore(scope): ...`, `docs(scope): ...`. +- PRs should include a clear summary, linked issue/spec, commands run for verification, and screenshots/GIFs for UI changes. +- For behavior/API changes, add or update `openspec/changes/...` artifacts. +- If dependencies change, commit `frontend/pnpm-lock.yaml` in the same PR. + +## Security & Configuration Tips +- Use `deploy/.env.example` and `deploy/config.example.yaml` as templates; do not commit real credentials. +- Set stable `JWT_SECRET`, `TOTP_ENCRYPTION_KEY`, and strong database passwords outside local dev. diff --git a/DEV_GUIDE.md b/DEV_GUIDE.md new file mode 100644 index 00000000..d0d362e0 --- /dev/null +++ b/DEV_GUIDE.md @@ -0,0 +1,346 @@ +# sub2api 项目开发指南 + +> 本文档记录项目环境配置、常见坑点和注意事项,供 Claude Code 和团队成员参考。 + +## 一、项目基本信息 + +| 项目 | 说明 | +|------|------| +| **上游仓库** | Wei-Shaw/sub2api | +| **Fork 仓库** | bayma888/sub2api-bmai | +| **技术栈** | Go 后端 (Ent ORM + Gin) + Vue3 前端 (pnpm) | +| **数据库** | PostgreSQL 16 + Redis | +| **包管理** | 后端: go modules, 前端: **pnpm**(不是 npm) | + +## 二、本地环境配置 + +### PostgreSQL 16 (Windows 服务) + +| 配置项 | 值 | +|--------|-----| +| 端口 | 5432 | +| psql 路径 | `C:\Program Files\PostgreSQL\16\bin\psql.exe` | +| pg_hba.conf | `C:\Program Files\PostgreSQL\16\data\pg_hba.conf` | +| 数据库凭据 | user=`sub2api`, password=`sub2api`, dbname=`sub2api` | +| 超级用户 | user=`postgres`, password=`postgres` | + +### Redis + +| 配置项 | 值 | +|--------|-----| +| 端口 | 6379 | +| 密码 | 无 | + +### 开发工具 + +```bash +# golangci-lint v2.7 +go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.7 + +# pnpm (前端包管理) +npm install -g pnpm +``` + +## 三、CI/CD 流水线 + +### GitHub Actions Workflows + +| Workflow | 触发条件 | 检查内容 | +|----------|----------|----------| +| **backend-ci.yml** | push, pull_request | 单元测试 + 集成测试 + golangci-lint v2.7 | +| **security-scan.yml** | push, pull_request, 每周一 | govulncheck + gosec + pnpm audit | +| **release.yml** | tag `v*` | 构建发布(PR 不触发) | + +### CI 要求 + +- Go 版本必须是 **1.25.7** +- 前端使用 `pnpm install --frozen-lockfile`,必须提交 `pnpm-lock.yaml` + +### 本地测试命令 + +```bash +# 后端单元测试 +cd backend && go test -tags=unit ./... + +# 后端集成测试 +cd backend && go test -tags=integration ./... + +# 代码质量检查 +cd backend && golangci-lint run ./... + +# 前端依赖安装(必须用 pnpm) +cd frontend && pnpm install +``` + +## 四、常见坑点 & 解决方案 + +### 坑 1:pnpm-lock.yaml 必须同步提交 + +**问题**:`package.json` 新增依赖后,CI 的 `pnpm install --frozen-lockfile` 失败。 + +**原因**:上游 CI 使用 pnpm,lock 文件不同步会报错。 + +**解决**: +```bash +cd frontend +pnpm install # 更新 pnpm-lock.yaml +git add pnpm-lock.yaml +git commit -m "chore: update pnpm-lock.yaml" +``` + +--- + +### 坑 2:npm 和 pnpm 的 node_modules 冲突 + +**问题**:之前用 npm 装过 `node_modules`,pnpm install 报 `EPERM` 错误。 + +**解决**: +```bash +cd frontend +rm -rf node_modules # 或 PowerShell: Remove-Item -Recurse -Force node_modules +pnpm install +``` + +--- + +### 坑 3:PowerShell 中 bcrypt hash 的 `$` 被转义 + +**问题**:bcrypt hash 格式如 `$2a$10$xxx...`,PowerShell 把 `$2a` 当变量解析,导致数据丢失。 + +**解决**:将 SQL 写入文件,用 `psql -f` 执行: +```bash +# 错误示范(PowerShell 会吃掉 $) +psql -c "INSERT INTO users ... VALUES ('$2a$10$...')" + +# 正确做法 +echo "INSERT INTO users ... VALUES ('\$2a\$10\$...')" > temp.sql +psql -U sub2api -h 127.0.0.1 -d sub2api -f temp.sql +``` + +--- + +### 坑 4:psql 不支持中文路径 + +**问题**:`psql -f "D:\中文路径\file.sql"` 报错找不到文件。 + +**解决**:复制到纯英文路径再执行: +```bash +cp "D:\中文路径\file.sql" "C:\temp.sql" +psql -f "C:\temp.sql" +``` + +--- + +### 坑 5:PostgreSQL 密码重置流程 + +**场景**:忘记 PostgreSQL 密码。 + +**步骤**: +1. 修改 `C:\Program Files\PostgreSQL\16\data\pg_hba.conf` + ``` + # 将 scram-sha-256 改为 trust + host all all 127.0.0.1/32 trust + ``` +2. 重启 PostgreSQL 服务 + ```powershell + Restart-Service postgresql-x64-16 + ``` +3. 无密码登录并重置 + ```bash + psql -U postgres -h 127.0.0.1 + ALTER USER sub2api WITH PASSWORD 'sub2api'; + ALTER USER postgres WITH PASSWORD 'postgres'; + ``` +4. 改回 `scram-sha-256` 并重启 + +--- + +### 坑 6:Go interface 新增方法后 test stub 必须补全 + +**问题**:给 interface 新增方法后,编译报错 `does not implement interface (missing method XXX)`。 + +**原因**:所有测试文件中实现该 interface 的 stub/mock 都必须补上新方法。 + +**解决**: +```bash +# 搜索所有实现该 interface 的 struct +cd backend +grep -r "type.*Stub.*struct" internal/ +grep -r "type.*Mock.*struct" internal/ + +# 逐一补全新方法 +``` + +--- + +### 坑 7:Windows 上 psql 连 localhost 的 IPv6 问题 + +**问题**:psql 连 `localhost` 先尝试 IPv6 (::1),可能报错后再回退 IPv4。 + +**建议**:直接用 `127.0.0.1` 代替 `localhost`。 + +--- + +### 坑 8:Windows 没有 make 命令 + +**问题**:CI 里用 `make test-unit`,本地 Windows 没有 make。 + +**解决**:直接用 Makefile 里的原始命令: +```bash +# 代替 make test-unit +go test -tags=unit ./... + +# 代替 make test-integration +go test -tags=integration ./... +``` + +--- + +### 坑 9:Ent Schema 修改后必须重新生成 + +**问题**:修改 `ent/schema/*.go` 后,代码不生效。 + +**解决**: +```bash +cd backend +go generate ./ent # 重新生成 ent 代码 +git add ent/ # 生成的文件也要提交 +``` + +--- + +### 坑 10:前端测试看似正常,但后端调用失败(模型映射被批量误改) + +**典型现象**: +- 前端按钮点测看起来正常; +- 实际通过 API/客户端调用时返回 `Service temporarily unavailable` 或提示无可用账号; +- 常见于 OpenAI 账号(例如 Codex 模型)在批量修改后突然不可用。 + +**根因**: +- OpenAI 账号编辑页默认不显式展示映射规则,容易让人误以为“没映射也没关系”; +- 但在**批量修改同时选中不同平台账号**(OpenAI + Antigravity/Gemini)时,模型白名单/映射可能被跨平台策略覆盖; +- 结果是 OpenAI 账号的关键模型映射丢失或被改坏,后端选不到可用账号。 + +**修复方案(按优先级)**: +1. **快速修复(推荐)**:在批量修改中补回正确的透传映射(例如 `gpt-5.3-codex -> gpt-5.3-codex-spark`)。 +2. **彻底重建**:删除并重新添加全部相关账号(最稳但成本高)。 + +**关键经验**: +- 如果某模型已被软件内置默认映射覆盖,通常不需要额外再加透传; +- 但当上游模型更新快于本仓库默认映射时,**手动批量添加透传映射**是最简单、最低风险的临时兜底方案; +- 批量操作前尽量按平台分组,不要混选不同平台账号。 + +--- + +### 坑 11:PR 提交前检查清单 + +提交 PR 前务必本地验证: + +- [ ] `go test -tags=unit ./...` 通过 +- [ ] `go test -tags=integration ./...` 通过 +- [ ] `golangci-lint run ./...` 无新增问题 +- [ ] `pnpm-lock.yaml` 已同步(如果改了 package.json) +- [ ] 所有 test stub 补全新接口方法(如果改了 interface) +- [ ] Ent 生成的代码已提交(如果改了 schema) + +## 五、常用命令速查 + +### 数据库操作 + +```bash +# 连接数据库 +psql -U sub2api -h 127.0.0.1 -d sub2api + +# 查看所有用户 +psql -U postgres -h 127.0.0.1 -c "\du" + +# 查看所有数据库 +psql -U postgres -h 127.0.0.1 -c "\l" + +# 执行 SQL 文件 +psql -U sub2api -h 127.0.0.1 -d sub2api -f migration.sql +``` + +### Git 操作 + +```bash +# 同步上游 +git fetch upstream +git checkout main +git merge upstream/main +git push origin main + +# 创建功能分支 +git checkout -b feature/xxx + +# Rebase 到最新 main +git fetch upstream +git rebase upstream/main +``` + +### 前端操作 + +```bash +# 安装依赖(必须用 pnpm) +cd frontend +pnpm install + +# 开发服务器 +pnpm dev + +# 构建 +pnpm build +``` + +### 后端操作 + +```bash +# 运行服务器 +cd backend +go run ./cmd/server/ + +# 生成 Ent 代码 +go generate ./ent + +# 运行测试 +go test -tags=unit ./... +go test -tags=integration ./... + +# Lint 检查 +golangci-lint run ./... +``` + +## 六、项目结构速览 + +``` +sub2api-bmai/ +├── backend/ +│ ├── cmd/server/ # 主程序入口 +│ ├── ent/ # Ent ORM 生成代码 +│ │ └── schema/ # 数据库 Schema 定义 +│ ├── internal/ +│ │ ├── handler/ # HTTP 处理器 +│ │ ├── service/ # 业务逻辑 +│ │ ├── repository/ # 数据访问层 +│ │ └── server/ # 服务器配置 +│ ├── migrations/ # 数据库迁移脚本 +│ └── config.yaml # 配置文件 +├── frontend/ +│ ├── src/ +│ │ ├── api/ # API 调用 +│ │ ├── components/ # Vue 组件 +│ │ ├── views/ # 页面视图 +│ │ ├── types/ # TypeScript 类型 +│ │ └── i18n/ # 国际化 +│ ├── package.json # 依赖配置 +│ └── pnpm-lock.yaml # pnpm 锁文件(必须提交) +└── .claude/ + └── CLAUDE.md # 本文档 +``` + +## 七、参考资源 + +- [上游仓库](https://github.com/Wei-Shaw/sub2api) +- [Ent 文档](https://entgo.io/docs/getting-started) +- [Vue3 文档](https://vuejs.org/) +- [pnpm 文档](https://pnpm.io/) diff --git a/Dockerfile b/Dockerfile index c9fcf301..1493e8a7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,7 +8,7 @@ ARG NODE_IMAGE=node:24-alpine ARG GOLANG_IMAGE=golang:1.25.7-alpine -ARG ALPINE_IMAGE=alpine:3.20 +ARG ALPINE_IMAGE=alpine:3.21 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn @@ -36,7 +36,7 @@ RUN pnpm run build FROM ${GOLANG_IMAGE} AS backend-builder # Build arguments for version info (set by CI) -ARG VERSION=docker +ARG VERSION= ARG COMMIT=docker ARG DATE ARG GOPROXY @@ -61,9 +61,14 @@ COPY backend/ ./ COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist # Build the binary (BuildType=release for CI builds, embed frontend) -RUN CGO_ENABLED=0 GOOS=linux go build \ +# Version precedence: build arg VERSION > cmd/server/VERSION +RUN VERSION_VALUE="${VERSION}" && \ + if [ -z "${VERSION_VALUE}" ]; then VERSION_VALUE="$(tr -d '\r\n' < ./cmd/server/VERSION)"; fi && \ + DATE_VALUE="${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)}" && \ + CGO_ENABLED=0 GOOS=linux go build \ -tags embed \ - -ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \ + -ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \ + -trimpath \ -o /app/sub2api \ ./cmd/server @@ -81,7 +86,6 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" RUN apk add --no-cache \ ca-certificates \ tzdata \ - curl \ && rm -rf /var/cache/apk/* # Create non-root user @@ -91,11 +95,12 @@ RUN addgroup -g 1000 sub2api && \ # Set working directory WORKDIR /app -# Copy binary from builder -COPY --from=backend-builder /app/sub2api /app/sub2api +# Copy binary/resources with ownership to avoid extra full-layer chown copy +COPY --from=backend-builder --chown=sub2api:sub2api /app/sub2api /app/sub2api +COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/resources # Create data directory -RUN mkdir -p /app/data && chown -R sub2api:sub2api /app +RUN mkdir -p /app/data && chown sub2api:sub2api /app/data # Switch to non-root user USER sub2api @@ -105,7 +110,7 @@ EXPOSE 8080 # Health check HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ - CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1 + CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1 # Run the application ENTRYPOINT ["/app/sub2api"] diff --git a/Makefile b/Makefile index a5e18a37..fd6a5a9a 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build build-backend build-frontend test test-backend test-frontend +.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-datamanagementd secret-scan # 一键编译前后端 build: build-backend build-frontend @@ -11,6 +11,10 @@ build-backend: build-frontend: @pnpm --dir frontend run build +# 编译 datamanagementd(宿主机数据管理进程) +build-datamanagementd: + @cd datamanagement && go build -o datamanagementd ./cmd/datamanagementd + # 运行测试(后端 + 前端) test: test-backend test-frontend @@ -20,3 +24,9 @@ test-backend: test-frontend: @pnpm --dir frontend run lint:check @pnpm --dir frontend run typecheck + +test-datamanagementd: + @cd datamanagement && go test ./... + +secret-scan: + @python3 tools/secret_scan.py diff --git a/README.md b/README.md index 36949b0a..1e2f2290 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot ## Documentation - Dependency Security: `docs/dependency-security.md` +- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md` --- @@ -363,6 +364,12 @@ default: rate_multiplier: 1.0 ``` +### Sora Status (Temporarily Unavailable) + +> ⚠️ Sora-related features are temporarily unavailable due to technical issues in upstream integration and media delivery. +> Please do not rely on Sora in production at this time. +> Existing `gateway.sora_*` configuration keys are reserved and may not take effect until these issues are resolved. + Additional security-related options are available in `config.yaml`: - `cors.allowed_origins` for CORS allowlist diff --git a/README_CN.md b/README_CN.md index 1e0d1d62..9da089b7 100644 --- a/README_CN.md +++ b/README_CN.md @@ -62,8 +62,6 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( - 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。 - 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。 ---- - ## 部署方式 ### 方式一:脚本安装(推荐) @@ -139,6 +137,8 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install 使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。 +如果你的服务器是 **Ubuntu 24.04**,建议直接参考:`deploy/ubuntu24-docker-compose-aicodex.md`,其中包含「安装最新版 Docker + docker-compose-aicodex.yml 部署」的完整步骤。 + #### 前置条件 - Docker 20.10+ @@ -244,6 +244,18 @@ docker-compose -f docker-compose.local.yml logs -f sub2api **推荐:** 使用 `docker-compose.local.yml`(脚本部署)以便更轻松地管理数据。 +#### 启用“数据管理”功能(datamanagementd) + +如需启用管理后台“数据管理”,需要额外部署宿主机数据管理进程 `datamanagementd`。 + +关键点: + +- 主进程固定探测:`/tmp/sub2api-datamanagement.sock` +- 只有该 Socket 可连通时,数据管理功能才会开启 +- Docker 场景需将宿主机 Socket 挂载到容器同路径 + +详细部署步骤见:`deploy/DATAMANAGEMENTD_CN.md` + #### 访问 在浏览器中打开 `http://你的服务器IP:8080` @@ -370,6 +382,33 @@ default: rate_multiplier: 1.0 ``` +### Sora 功能状态(暂不可用) + +> ⚠️ 当前 Sora 相关功能因上游接入与媒体链路存在技术问题,暂时不可用。 +> 现阶段请勿在生产环境依赖 Sora 能力。 +> 文档中的 `gateway.sora_*` 配置仅作预留,待技术问题修复后再恢复可用。 + +### Sora 媒体签名 URL(功能恢复后可选) + +当配置 `gateway.sora_media_signing_key` 且 `gateway.sora_media_signed_url_ttl_seconds > 0` 时,网关会将 Sora 输出的媒体地址改写为临时签名 URL(`/sora/media-signed/...`)。这样无需 API Key 即可在浏览器中直接访问,且具备过期控制与防篡改能力(签名包含 path + query)。 + +```yaml +gateway: + # /sora/media 是否强制要求 API Key(默认 false) + sora_media_require_api_key: false + # 媒体临时签名密钥(为空则禁用签名) + sora_media_signing_key: "your-signing-key" + # 临时签名 URL 有效期(秒) + sora_media_signed_url_ttl_seconds: 900 +``` + +> 若未配置签名密钥,`/sora/media-signed` 将返回 503。 +> 如需更严格的访问控制,可将 `sora_media_require_api_key` 设为 true,仅允许携带 API Key 的 `/sora/media` 访问。 + +访问策略说明: +- `/sora/media`:内部调用或客户端携带 API Key 才能下载 +- `/sora/media-signed`:外部可访问,但有签名 + 过期控制 + `config.yaml` 还支持以下安全相关配置: - `cors.allowed_origins` 配置 CORS 白名单 @@ -383,6 +422,14 @@ default: - `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For - `turnstile.required` 在 release 模式强制启用 Turnstile +**网关防御纵深建议(重点)** + +- `gateway.upstream_response_read_max_bytes`:限制非流式上游响应读取大小(默认 `8MB`),用于防止异常响应导致内存放大。 +- `gateway.proxy_probe_response_read_max_bytes`:限制代理探测响应读取大小(默认 `1MB`)。 +- `gateway.gemini_debug_response_headers`:默认 `false`,仅在排障时短时开启,避免高频请求日志开销。 +- `/auth/register`、`/auth/login`、`/auth/login/2fa`、`/auth/send-verify-code` 已提供服务端兜底限流(Redis 故障时 fail-close)。 +- 推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。 + **⚠️ 安全警告:HTTP URL 配置** 当 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置: @@ -428,6 +475,29 @@ Invalid base URL: invalid url scheme: http ./sub2api ``` +#### HTTP/2 (h2c) 与 HTTP/1.1 回退 + +后端明文端口默认支持 h2c,并保留 HTTP/1.1 回退用于 WebSocket 与旧客户端。浏览器通常不支持 h2c,性能收益主要在反向代理或内网链路。 + +**反向代理示例(Caddy):** + +```caddyfile +transport http { + versions h2c h1 +} +``` + +**验证:** + +```bash +# h2c prior knowledge +curl --http2-prior-knowledge -I http://localhost:8080/health +# HTTP/1.1 回退 +curl --http1.1 -I http://localhost:8080/health +# WebSocket 回退验证(需管理员 token) +websocat -H="Sec-WebSocket-Protocol: sub2api-admin, jwt." ws://localhost:8080/api/v1/admin/ops/ws/qps +``` + #### 开发模式 ```bash diff --git a/backend/.golangci.yml b/backend/.golangci.yml index 3ec692a8..68b76751 100644 --- a/backend/.golangci.yml +++ b/backend/.golangci.yml @@ -5,6 +5,7 @@ linters: enable: - depguard - errcheck + - gosec - govet - ineffassign - staticcheck @@ -42,6 +43,22 @@ linters: desc: "handler must not import gorm" - pkg: github.com/redis/go-redis/v9 desc: "handler must not import redis" + gosec: + excludes: + - G101 + - G103 + - G104 + - G109 + - G115 + - G201 + - G202 + - G301 + - G302 + - G304 + - G306 + - G404 + severity: high + confidence: high errcheck: # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. # Such cases aren't reported by default. diff --git a/backend/Makefile b/backend/Makefile index 6a5d2caa..7084ccb9 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -1,7 +1,14 @@ -.PHONY: build test test-unit test-integration test-e2e +.PHONY: build generate test test-unit test-integration test-e2e + +VERSION ?= $(shell tr -d '\r\n' < ./cmd/server/VERSION) +LDFLAGS ?= -s -w -X main.Version=$(VERSION) build: - go build -o bin/server ./cmd/server + CGO_ENABLED=0 go build -ldflags="$(LDFLAGS)" -trimpath -o bin/server ./cmd/server + +generate: + go generate ./ent + go generate ./cmd/server test: go test ./... @@ -14,4 +21,7 @@ test-integration: go test -tags=integration ./... test-e2e: - go test -tags=e2e ./... + ./scripts/e2e-test.sh + +test-e2e-local: + go test -tags=e2e -v -timeout=300s ./internal/integration/... diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index ce4718bf..bc001693 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -17,7 +17,7 @@ func main() { email := flag.String("email", "", "Admin email to issue a JWT for (defaults to first active admin)") flag.Parse() - cfg, err := config.Load() + cfg, err := config.LoadForBootstrap() if err != nil { log.Fatalf("failed to load config: %v", err) } @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) + authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index f0768f09..32844913 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.70 +0.1.88 \ No newline at end of file diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index f8a7d313..46edcb69 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -8,7 +8,6 @@ import ( "errors" "flag" "log" - "log/slog" "net/http" "os" "os/signal" @@ -19,11 +18,14 @@ import ( _ "github.com/Wei-Shaw/sub2api/ent/runtime" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/setup" "github.com/Wei-Shaw/sub2api/internal/web" "github.com/gin-gonic/gin" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) //go:embed VERSION @@ -38,7 +40,12 @@ var ( ) func init() { - // Read version from embedded VERSION file + // 如果 Version 已通过 ldflags 注入(例如 -X main.Version=...),则不要覆盖。 + if strings.TrimSpace(Version) != "" { + return + } + + // 默认从 embedded VERSION 文件读取版本号(编译期打包进二进制)。 Version = strings.TrimSpace(embeddedVersion) if Version == "" { Version = "0.0.0-dev" @@ -47,22 +54,9 @@ func init() { // initLogger configures the default slog handler based on gin.Mode(). // In non-release mode, Debug level logs are enabled. -func initLogger() { - var level slog.Level - if gin.Mode() == gin.ReleaseMode { - level = slog.LevelInfo - } else { - level = slog.LevelDebug - } - handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: level, - }) - slog.SetDefault(slog.New(handler)) -} - func main() { - // Initialize slog logger based on gin mode - initLogger() + logger.InitBootstrap() + defer logger.Sync() // Parse command line flags setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode") @@ -106,7 +100,7 @@ func runSetupServer() { r := gin.New() r.Use(middleware.Recovery()) r.Use(middleware.CORS(config.CORSConfig{})) - r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy})) + r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}, nil)) // Register setup routes setup.RegisterRoutes(r) @@ -122,16 +116,26 @@ func runSetupServer() { log.Printf("Setup wizard available at http://%s", addr) log.Println("Complete the setup wizard to configure Sub2API") - if err := r.Run(addr); err != nil { + server := &http.Server{ + Addr: addr, + Handler: h2c.NewHandler(r, &http2.Server{}), + ReadHeaderTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } + + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("Failed to start setup server: %v", err) } } func runMainServer() { - cfg, err := config.Load() + cfg, err := config.LoadForBootstrap() if err != nil { log.Fatalf("Failed to load config: %v", err) } + if err := logger.Init(logger.OptionsFromConfig(cfg.Log)); err != nil { + log.Fatalf("Failed to initialize logger: %v", err) + } if cfg.RunMode == config.RunModeSimple { log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED") } diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index d9ff788e..cbf89ba3 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -7,6 +7,7 @@ import ( "context" "log" "net/http" + "sync" "time" "github.com/Wei-Shaw/sub2api/ent" @@ -67,28 +68,36 @@ func provideCleanup( opsAlertEvaluator *service.OpsAlertEvaluatorService, opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, + opsSystemLogSink *service.OpsSystemLogSink, + soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, + idempotencyCleanup *service.IdempotencyCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, + subscriptionService *service.SubscriptionService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, + openAIGateway *service.OpenAIGatewayService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - // Cleanup steps in reverse dependency order - cleanupSteps := []struct { + type cleanupStep struct { name string fn func() error - }{ + } + + // 应用层清理步骤可并行执行,基础设施资源(Redis/Ent)最后按顺序关闭。 + parallelSteps := []cleanupStep{ {"OpsScheduledReportService", func() error { if opsScheduledReport != nil { opsScheduledReport.Stop() @@ -101,6 +110,18 @@ func provideCleanup( } return nil }}, + {"OpsSystemLogSink", func() error { + if opsSystemLogSink != nil { + opsSystemLogSink.Stop() + } + return nil + }}, + {"SoraMediaCleanupService", func() error { + if soraMediaCleanup != nil { + soraMediaCleanup.Stop() + } + return nil + }}, {"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() @@ -131,6 +152,12 @@ func provideCleanup( } return nil }}, + {"IdempotencyCleanupService", func() error { + if idempotencyCleanup != nil { + idempotencyCleanup.Stop() + } + return nil + }}, {"TokenRefreshService", func() error { tokenRefresh.Stop() return nil @@ -143,6 +170,12 @@ func provideCleanup( subscriptionExpiry.Stop() return nil }}, + {"SubscriptionService", func() error { + if subscriptionService != nil { + subscriptionService.Stop() + } + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil @@ -155,6 +188,12 @@ func provideCleanup( billingCache.Stop() return nil }}, + {"UsageRecordWorkerPool", func() error { + if usageRecordWorkerPool != nil { + usageRecordWorkerPool.Stop() + } + return nil + }}, {"OAuthService", func() error { oauth.Stop() return nil @@ -171,23 +210,60 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, + {"OpenAIWSPool", func() error { + if openAIGateway != nil { + openAIGateway.CloseOpenAIWSPool() + } + return nil + }}, + } + + infraSteps := []cleanupStep{ {"Redis", func() error { + if rdb == nil { + return nil + } return rdb.Close() }}, {"Ent", func() error { + if entClient == nil { + return nil + } return entClient.Close() }}, } - for _, step := range cleanupSteps { - if err := step.fn(); err != nil { - log.Printf("[Cleanup] %s failed: %v", step.name, err) - // Continue with remaining cleanup steps even if one fails - } else { + runParallel := func(steps []cleanupStep) { + var wg sync.WaitGroup + for i := range steps { + step := steps[i] + wg.Add(1) + go func() { + defer wg.Done() + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + return + } + log.Printf("[Cleanup] %s succeeded", step.name) + }() + } + wg.Wait() + } + + runSequential := func(steps []cleanupStep) { + for i := range steps { + step := steps[i] + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + continue + } log.Printf("[Cleanup] %s succeeded", step.name) } } + runParallel(parallelSteps) + runSequential(infraSteps) + // Check if context timed out select { case <-ctx.Done(): diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ab1831d8..2e9afc26 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -19,6 +19,7 @@ import ( "github.com/redis/go-redis/v9" "log" "net/http" + "sync" "time" ) @@ -47,7 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { redisClient := repository.ProvideRedis(configConfig) refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) - settingService := service.NewSettingService(settingRepository, configConfig) + groupRepository := repository.NewGroupRepository(client, db) + settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) turnstileVerifier := repository.NewTurnstileVerifier() @@ -58,15 +60,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) apiKeyRepository := repository.NewAPIKeyRepository(client) - groupRepository := repository.NewGroupRepository(client, db) userGroupRateRepository := repository.NewUserGroupRateRepository(db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) - authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) - userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) - subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) + subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) + authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) + userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) secretEncryptor, err := repository.NewAESEncryptor(configConfig) @@ -98,11 +99,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) schedulerCache := repository.NewSchedulerCache(redisClient) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) + soraAccountRepository := repository.NewSoraAccountRepository(db) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) - adminUserHandler := admin.NewUserHandler(adminService) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService) + concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) + concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) + adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) @@ -110,7 +114,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() - geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig) + driveClient := repository.NewGeminiDriveClient() + geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig) antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) tempUnschedCache := repository.NewTempUnschedCache(redisClient) @@ -126,23 +131,24 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) gatewayCache := repository.NewGatewayCache(redisClient) - antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) + antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) - concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) - concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) - accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator) + rpmCache := repository.NewRPMCache(redisClient) + accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) + dataManagementService := service.NewDataManagementService() + dataManagementHandler := admin.NewDataManagementHandler(dataManagementService) oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService) proxyHandler := admin.NewProxyHandler(adminService) - adminRedeemHandler := admin.NewRedeemHandler(adminService) + adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService) promoHandler := admin.NewPromoHandler(promoService) opsRepository := repository.NewOpsRepository(db) pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) @@ -154,18 +160,27 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) + digestSessionStore := service.NewDigestSessionStore() + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) - opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService) + opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) + opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) + soraS3Storage := service.NewSoraS3Storage(settingService) + settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient) + soraGenerationRepository := repository.NewSoraGenerationRepository(db) + soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService) + soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) serviceBuildInfo := provideServiceBuildInfo(buildInfo) updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) - systemHandler := handler.ProvideSystemHandler(updateService) + idempotencyRepository := repository.NewIdempotencyRepository(client, db) + systemOperationLockService := service.ProvideSystemOperationLockService(idempotencyRepository, configConfig) + systemHandler := handler.ProvideSystemHandler(updateService, systemOperationLockService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) usageCleanupRepository := repository.NewUsageCleanupRepository(client, db) usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig) @@ -178,12 +193,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient) errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) + adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler) + usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) + userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) + userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) + soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) + soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) + soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig) + soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService) + soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler) + idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) + idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -194,10 +220,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) + soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService) application := &Application{ Server: httpServer, Cleanup: v, @@ -227,27 +254,35 @@ func provideCleanup( opsAlertEvaluator *service.OpsAlertEvaluatorService, opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, + opsSystemLogSink *service.OpsSystemLogSink, + soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, + idempotencyCleanup *service.IdempotencyCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, + subscriptionService *service.SubscriptionService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, + openAIGateway *service.OpenAIGatewayService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cleanupSteps := []struct { + type cleanupStep struct { name string fn func() error - }{ + } + + parallelSteps := []cleanupStep{ {"OpsScheduledReportService", func() error { if opsScheduledReport != nil { opsScheduledReport.Stop() @@ -260,6 +295,18 @@ func provideCleanup( } return nil }}, + {"OpsSystemLogSink", func() error { + if opsSystemLogSink != nil { + opsSystemLogSink.Stop() + } + return nil + }}, + {"SoraMediaCleanupService", func() error { + if soraMediaCleanup != nil { + soraMediaCleanup.Stop() + } + return nil + }}, {"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() @@ -290,6 +337,12 @@ func provideCleanup( } return nil }}, + {"IdempotencyCleanupService", func() error { + if idempotencyCleanup != nil { + idempotencyCleanup.Stop() + } + return nil + }}, {"TokenRefreshService", func() error { tokenRefresh.Stop() return nil @@ -302,6 +355,12 @@ func provideCleanup( subscriptionExpiry.Stop() return nil }}, + {"SubscriptionService", func() error { + if subscriptionService != nil { + subscriptionService.Stop() + } + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil @@ -314,6 +373,12 @@ func provideCleanup( billingCache.Stop() return nil }}, + {"UsageRecordWorkerPool", func() error { + if usageRecordWorkerPool != nil { + usageRecordWorkerPool.Stop() + } + return nil + }}, {"OAuthService", func() error { oauth.Stop() return nil @@ -330,23 +395,60 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, + {"OpenAIWSPool", func() error { + if openAIGateway != nil { + openAIGateway.CloseOpenAIWSPool() + } + return nil + }}, + } + + infraSteps := []cleanupStep{ {"Redis", func() error { + if rdb == nil { + return nil + } return rdb.Close() }}, {"Ent", func() error { + if entClient == nil { + return nil + } return entClient.Close() }}, } - for _, step := range cleanupSteps { - if err := step.fn(); err != nil { - log.Printf("[Cleanup] %s failed: %v", step.name, err) + runParallel := func(steps []cleanupStep) { + var wg sync.WaitGroup + for i := range steps { + step := steps[i] + wg.Add(1) + go func() { + defer wg.Done() + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + return + } + log.Printf("[Cleanup] %s succeeded", step.name) + }() + } + wg.Wait() + } - } else { + runSequential := func(steps []cleanupStep) { + for i := range steps { + step := steps[i] + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + continue + } log.Printf("[Cleanup] %s succeeded", step.name) } } + runParallel(parallelSteps) + runSequential(infraSteps) + select { case <-ctx.Done(): log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds") diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go new file mode 100644 index 00000000..9fb9888d --- /dev/null +++ b/backend/cmd/server/wire_gen_test.go @@ -0,0 +1,81 @@ +package main + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestProvideServiceBuildInfo(t *testing.T) { + in := handler.BuildInfo{ + Version: "v-test", + BuildType: "release", + } + out := provideServiceBuildInfo(in) + require.Equal(t, in.Version, out.Version) + require.Equal(t, in.BuildType, out.BuildType) +} + +func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { + cfg := &config.Config{} + + oauthSvc := service.NewOAuthService(nil, nil) + openAIOAuthSvc := service.NewOpenAIOAuthService(nil, nil) + geminiOAuthSvc := service.NewGeminiOAuthService(nil, nil, nil, nil, cfg) + antigravityOAuthSvc := service.NewAntigravityOAuthService(nil) + + tokenRefreshSvc := service.NewTokenRefreshService( + nil, + oauthSvc, + openAIOAuthSvc, + geminiOAuthSvc, + antigravityOAuthSvc, + nil, + nil, + cfg, + ) + accountExpirySvc := service.NewAccountExpiryService(nil, time.Second) + subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) + pricingSvc := service.NewPricingService(cfg, nil) + emailQueueSvc := service.NewEmailQueueService(nil, 1) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg) + idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) + schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) + opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) + + cleanup := provideCleanup( + nil, // entClient + nil, // redis + &service.OpsMetricsCollector{}, + &service.OpsAggregationService{}, + &service.OpsAlertEvaluatorService{}, + &service.OpsCleanupService{}, + &service.OpsScheduledReportService{}, + opsSystemLogSinkSvc, + &service.SoraMediaCleanupService{}, + schedulerSnapshotSvc, + tokenRefreshSvc, + accountExpirySvc, + subscriptionExpirySvc, + &service.UsageCleanupService{}, + idempotencyCleanupSvc, + pricingSvc, + emailQueueSvc, + billingCacheSvc, + &service.UsageRecordWorkerPool{}, + &service.SubscriptionService{}, + oauthSvc, + openAIOAuthSvc, + geminiOAuthSvc, + antigravityOAuthSvc, + nil, // openAIGateway + ) + + require.NotPanics(t, func() { + cleanup() + }) +} diff --git a/backend/ent/account.go b/backend/ent/account.go index 038aa7e5..c77002b3 100644 --- a/backend/ent/account.go +++ b/backend/ent/account.go @@ -63,6 +63,10 @@ type Account struct { RateLimitResetAt *time.Time `json:"rate_limit_reset_at,omitempty"` // OverloadUntil holds the value of the "overload_until" field. OverloadUntil *time.Time `json:"overload_until,omitempty"` + // TempUnschedulableUntil holds the value of the "temp_unschedulable_until" field. + TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"` + // TempUnschedulableReason holds the value of the "temp_unschedulable_reason" field. + TempUnschedulableReason *string `json:"temp_unschedulable_reason,omitempty"` // SessionWindowStart holds the value of the "session_window_start" field. SessionWindowStart *time.Time `json:"session_window_start,omitempty"` // SessionWindowEnd holds the value of the "session_window_end" field. @@ -141,9 +145,9 @@ func (*Account) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority: values[i] = new(sql.NullInt64) - case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus: + case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus: values[i] = new(sql.NullString) - case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd: + case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldTempUnschedulableUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -311,6 +315,20 @@ func (_m *Account) assignValues(columns []string, values []any) error { _m.OverloadUntil = new(time.Time) *_m.OverloadUntil = value.Time } + case account.FieldTempUnschedulableUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field temp_unschedulable_until", values[i]) + } else if value.Valid { + _m.TempUnschedulableUntil = new(time.Time) + *_m.TempUnschedulableUntil = value.Time + } + case account.FieldTempUnschedulableReason: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field temp_unschedulable_reason", values[i]) + } else if value.Valid { + _m.TempUnschedulableReason = new(string) + *_m.TempUnschedulableReason = value.String + } case account.FieldSessionWindowStart: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field session_window_start", values[i]) @@ -472,6 +490,16 @@ func (_m *Account) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + if v := _m.TempUnschedulableUntil; v != nil { + builder.WriteString("temp_unschedulable_until=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.TempUnschedulableReason; v != nil { + builder.WriteString("temp_unschedulable_reason=") + builder.WriteString(*v) + } + builder.WriteString(", ") if v := _m.SessionWindowStart; v != nil { builder.WriteString("session_window_start=") builder.WriteString(v.Format(time.ANSIC)) diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go index 73c0e8c2..1fc34620 100644 --- a/backend/ent/account/account.go +++ b/backend/ent/account/account.go @@ -59,6 +59,10 @@ const ( FieldRateLimitResetAt = "rate_limit_reset_at" // FieldOverloadUntil holds the string denoting the overload_until field in the database. FieldOverloadUntil = "overload_until" + // FieldTempUnschedulableUntil holds the string denoting the temp_unschedulable_until field in the database. + FieldTempUnschedulableUntil = "temp_unschedulable_until" + // FieldTempUnschedulableReason holds the string denoting the temp_unschedulable_reason field in the database. + FieldTempUnschedulableReason = "temp_unschedulable_reason" // FieldSessionWindowStart holds the string denoting the session_window_start field in the database. FieldSessionWindowStart = "session_window_start" // FieldSessionWindowEnd holds the string denoting the session_window_end field in the database. @@ -128,6 +132,8 @@ var Columns = []string{ FieldRateLimitedAt, FieldRateLimitResetAt, FieldOverloadUntil, + FieldTempUnschedulableUntil, + FieldTempUnschedulableReason, FieldSessionWindowStart, FieldSessionWindowEnd, FieldSessionWindowStatus, @@ -299,6 +305,16 @@ func ByOverloadUntil(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldOverloadUntil, opts...).ToFunc() } +// ByTempUnschedulableUntil orders the results by the temp_unschedulable_until field. +func ByTempUnschedulableUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTempUnschedulableUntil, opts...).ToFunc() +} + +// ByTempUnschedulableReason orders the results by the temp_unschedulable_reason field. +func ByTempUnschedulableReason(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTempUnschedulableReason, opts...).ToFunc() +} + // BySessionWindowStart orders the results by the session_window_start field. func BySessionWindowStart(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldSessionWindowStart, opts...).ToFunc() diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go index dea1127a..54db1dcb 100644 --- a/backend/ent/account/where.go +++ b/backend/ent/account/where.go @@ -155,6 +155,16 @@ func OverloadUntil(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldOverloadUntil, v)) } +// TempUnschedulableUntil applies equality check predicate on the "temp_unschedulable_until" field. It's identical to TempUnschedulableUntilEQ. +func TempUnschedulableUntil(v time.Time) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableReason applies equality check predicate on the "temp_unschedulable_reason" field. It's identical to TempUnschedulableReasonEQ. +func TempUnschedulableReason(v string) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v)) +} + // SessionWindowStart applies equality check predicate on the "session_window_start" field. It's identical to SessionWindowStartEQ. func SessionWindowStart(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v)) @@ -1130,6 +1140,131 @@ func OverloadUntilNotNil() predicate.Account { return predicate.Account(sql.FieldNotNull(FieldOverloadUntil)) } +// TempUnschedulableUntilEQ applies the EQ predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilEQ(v time.Time) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilNEQ applies the NEQ predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNEQ(v time.Time) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilIn applies the In predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilIn(vs ...time.Time) predicate.Account { + return predicate.Account(sql.FieldIn(FieldTempUnschedulableUntil, vs...)) +} + +// TempUnschedulableUntilNotIn applies the NotIn predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNotIn(vs ...time.Time) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableUntil, vs...)) +} + +// TempUnschedulableUntilGT applies the GT predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilGT(v time.Time) predicate.Account { + return predicate.Account(sql.FieldGT(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilGTE applies the GTE predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilGTE(v time.Time) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilLT applies the LT predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilLT(v time.Time) predicate.Account { + return predicate.Account(sql.FieldLT(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilLTE applies the LTE predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilLTE(v time.Time) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilIsNil applies the IsNil predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableUntil)) +} + +// TempUnschedulableUntilNotNil applies the NotNil predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableUntil)) +} + +// TempUnschedulableReasonEQ applies the EQ predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonEQ(v string) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonNEQ applies the NEQ predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonNEQ(v string) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonIn applies the In predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonIn(vs ...string) predicate.Account { + return predicate.Account(sql.FieldIn(FieldTempUnschedulableReason, vs...)) +} + +// TempUnschedulableReasonNotIn applies the NotIn predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonNotIn(vs ...string) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableReason, vs...)) +} + +// TempUnschedulableReasonGT applies the GT predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonGT(v string) predicate.Account { + return predicate.Account(sql.FieldGT(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonGTE applies the GTE predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonGTE(v string) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonLT applies the LT predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonLT(v string) predicate.Account { + return predicate.Account(sql.FieldLT(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonLTE applies the LTE predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonLTE(v string) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonContains applies the Contains predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonContains(v string) predicate.Account { + return predicate.Account(sql.FieldContains(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonHasPrefix applies the HasPrefix predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonHasPrefix(v string) predicate.Account { + return predicate.Account(sql.FieldHasPrefix(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonHasSuffix applies the HasSuffix predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonHasSuffix(v string) predicate.Account { + return predicate.Account(sql.FieldHasSuffix(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonIsNil applies the IsNil predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableReason)) +} + +// TempUnschedulableReasonNotNil applies the NotNil predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableReason)) +} + +// TempUnschedulableReasonEqualFold applies the EqualFold predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonEqualFold(v string) predicate.Account { + return predicate.Account(sql.FieldEqualFold(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonContainsFold applies the ContainsFold predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonContainsFold(v string) predicate.Account { + return predicate.Account(sql.FieldContainsFold(FieldTempUnschedulableReason, v)) +} + // SessionWindowStartEQ applies the EQ predicate on the "session_window_start" field. func SessionWindowStartEQ(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v)) diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go index 42a561cf..963ffee8 100644 --- a/backend/ent/account_create.go +++ b/backend/ent/account_create.go @@ -293,6 +293,34 @@ func (_c *AccountCreate) SetNillableOverloadUntil(v *time.Time) *AccountCreate { return _c } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (_c *AccountCreate) SetTempUnschedulableUntil(v time.Time) *AccountCreate { + _c.mutation.SetTempUnschedulableUntil(v) + return _c +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_c *AccountCreate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountCreate { + if v != nil { + _c.SetTempUnschedulableUntil(*v) + } + return _c +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (_c *AccountCreate) SetTempUnschedulableReason(v string) *AccountCreate { + _c.mutation.SetTempUnschedulableReason(v) + return _c +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_c *AccountCreate) SetNillableTempUnschedulableReason(v *string) *AccountCreate { + if v != nil { + _c.SetTempUnschedulableReason(*v) + } + return _c +} + // SetSessionWindowStart sets the "session_window_start" field. func (_c *AccountCreate) SetSessionWindowStart(v time.Time) *AccountCreate { _c.mutation.SetSessionWindowStart(v) @@ -639,6 +667,14 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { _spec.SetField(account.FieldOverloadUntil, field.TypeTime, value) _node.OverloadUntil = &value } + if value, ok := _c.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + _node.TempUnschedulableUntil = &value + } + if value, ok := _c.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + _node.TempUnschedulableReason = &value + } if value, ok := _c.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) _node.SessionWindowStart = &value @@ -1080,6 +1116,42 @@ func (u *AccountUpsert) ClearOverloadUntil() *AccountUpsert { return u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (u *AccountUpsert) SetTempUnschedulableUntil(v time.Time) *AccountUpsert { + u.Set(account.FieldTempUnschedulableUntil, v) + return u +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. +func (u *AccountUpsert) UpdateTempUnschedulableUntil() *AccountUpsert { + u.SetExcluded(account.FieldTempUnschedulableUntil) + return u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsert) ClearTempUnschedulableUntil() *AccountUpsert { + u.SetNull(account.FieldTempUnschedulableUntil) + return u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsert) SetTempUnschedulableReason(v string) *AccountUpsert { + u.Set(account.FieldTempUnschedulableReason, v) + return u +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsert) UpdateTempUnschedulableReason() *AccountUpsert { + u.SetExcluded(account.FieldTempUnschedulableReason) + return u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsert) ClearTempUnschedulableReason() *AccountUpsert { + u.SetNull(account.FieldTempUnschedulableReason) + return u +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsert) SetSessionWindowStart(v time.Time) *AccountUpsert { u.Set(account.FieldSessionWindowStart, v) @@ -1557,6 +1629,48 @@ func (u *AccountUpsertOne) ClearOverloadUntil() *AccountUpsertOne { }) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (u *AccountUpsertOne) SetTempUnschedulableUntil(v time.Time) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableUntil(v) + }) +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. +func (u *AccountUpsertOne) UpdateTempUnschedulableUntil() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableUntil() + }) +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsertOne) ClearTempUnschedulableUntil() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableUntil() + }) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsertOne) SetTempUnschedulableReason(v string) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableReason(v) + }) +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsertOne) UpdateTempUnschedulableReason() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableReason() + }) +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsertOne) ClearTempUnschedulableReason() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableReason() + }) +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsertOne) SetSessionWindowStart(v time.Time) *AccountUpsertOne { return u.Update(func(s *AccountUpsert) { @@ -2209,6 +2323,48 @@ func (u *AccountUpsertBulk) ClearOverloadUntil() *AccountUpsertBulk { }) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (u *AccountUpsertBulk) SetTempUnschedulableUntil(v time.Time) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableUntil(v) + }) +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. +func (u *AccountUpsertBulk) UpdateTempUnschedulableUntil() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableUntil() + }) +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsertBulk) ClearTempUnschedulableUntil() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableUntil() + }) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsertBulk) SetTempUnschedulableReason(v string) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableReason(v) + }) +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsertBulk) UpdateTempUnschedulableReason() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableReason() + }) +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsertBulk) ClearTempUnschedulableReason() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableReason() + }) +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsertBulk) SetSessionWindowStart(v time.Time) *AccountUpsertBulk { return u.Update(func(s *AccountUpsert) { diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go index 63fab096..875888e0 100644 --- a/backend/ent/account_update.go +++ b/backend/ent/account_update.go @@ -376,6 +376,46 @@ func (_u *AccountUpdate) ClearOverloadUntil() *AccountUpdate { return _u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (_u *AccountUpdate) SetTempUnschedulableUntil(v time.Time) *AccountUpdate { + _u.mutation.SetTempUnschedulableUntil(v) + return _u +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdate { + if v != nil { + _u.SetTempUnschedulableUntil(*v) + } + return _u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (_u *AccountUpdate) ClearTempUnschedulableUntil() *AccountUpdate { + _u.mutation.ClearTempUnschedulableUntil() + return _u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (_u *AccountUpdate) SetTempUnschedulableReason(v string) *AccountUpdate { + _u.mutation.SetTempUnschedulableReason(v) + return _u +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableTempUnschedulableReason(v *string) *AccountUpdate { + if v != nil { + _u.SetTempUnschedulableReason(*v) + } + return _u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (_u *AccountUpdate) ClearTempUnschedulableReason() *AccountUpdate { + _u.mutation.ClearTempUnschedulableReason() + return _u +} + // SetSessionWindowStart sets the "session_window_start" field. func (_u *AccountUpdate) SetSessionWindowStart(v time.Time) *AccountUpdate { _u.mutation.SetSessionWindowStart(v) @@ -701,6 +741,18 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.OverloadUntilCleared() { _spec.ClearField(account.FieldOverloadUntil, field.TypeTime) } + if value, ok := _u.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + } + if _u.mutation.TempUnschedulableUntilCleared() { + _spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime) + } + if value, ok := _u.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + } + if _u.mutation.TempUnschedulableReasonCleared() { + _spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString) + } if value, ok := _u.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) } @@ -1215,6 +1267,46 @@ func (_u *AccountUpdateOne) ClearOverloadUntil() *AccountUpdateOne { return _u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (_u *AccountUpdateOne) SetTempUnschedulableUntil(v time.Time) *AccountUpdateOne { + _u.mutation.SetTempUnschedulableUntil(v) + return _u +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdateOne { + if v != nil { + _u.SetTempUnschedulableUntil(*v) + } + return _u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (_u *AccountUpdateOne) ClearTempUnschedulableUntil() *AccountUpdateOne { + _u.mutation.ClearTempUnschedulableUntil() + return _u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (_u *AccountUpdateOne) SetTempUnschedulableReason(v string) *AccountUpdateOne { + _u.mutation.SetTempUnschedulableReason(v) + return _u +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableTempUnschedulableReason(v *string) *AccountUpdateOne { + if v != nil { + _u.SetTempUnschedulableReason(*v) + } + return _u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (_u *AccountUpdateOne) ClearTempUnschedulableReason() *AccountUpdateOne { + _u.mutation.ClearTempUnschedulableReason() + return _u +} + // SetSessionWindowStart sets the "session_window_start" field. func (_u *AccountUpdateOne) SetSessionWindowStart(v time.Time) *AccountUpdateOne { _u.mutation.SetSessionWindowStart(v) @@ -1570,6 +1662,18 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er if _u.mutation.OverloadUntilCleared() { _spec.ClearField(account.FieldOverloadUntil, field.TypeTime) } + if value, ok := _u.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + } + if _u.mutation.TempUnschedulableUntilCleared() { + _spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime) + } + if value, ok := _u.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + } + if _u.mutation.TempUnschedulableReasonCleared() { + _spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString) + } if value, ok := _u.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) } diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index 91d71964..760851c8 100644 --- a/backend/ent/apikey.go +++ b/backend/ent/apikey.go @@ -36,6 +36,8 @@ type APIKey struct { GroupID *int64 `json:"group_id,omitempty"` // Status holds the value of the "status" field. Status string `json:"status,omitempty"` + // Last usage time of this API key + LastUsedAt *time.Time `json:"last_used_at,omitempty"` // Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"] IPWhitelist []string `json:"ip_whitelist,omitempty"` // Blocked IPs/CIDRs @@ -109,7 +111,7 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullInt64) case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: values[i] = new(sql.NullString) - case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldExpiresAt: + case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -182,6 +184,13 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Status = value.String } + case apikey.FieldLastUsedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_used_at", values[i]) + } else if value.Valid { + _m.LastUsedAt = new(time.Time) + *_m.LastUsedAt = value.Time + } case apikey.FieldIPWhitelist: if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i]) @@ -296,6 +305,11 @@ func (_m *APIKey) String() string { builder.WriteString("status=") builder.WriteString(_m.Status) builder.WriteString(", ") + if v := _m.LastUsedAt; v != nil { + builder.WriteString("last_used_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("ip_whitelist=") builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist)) builder.WriteString(", ") diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index ac2a6008..6abea56b 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -31,6 +31,8 @@ const ( FieldGroupID = "group_id" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" + // FieldLastUsedAt holds the string denoting the last_used_at field in the database. + FieldLastUsedAt = "last_used_at" // FieldIPWhitelist holds the string denoting the ip_whitelist field in the database. FieldIPWhitelist = "ip_whitelist" // FieldIPBlacklist holds the string denoting the ip_blacklist field in the database. @@ -83,6 +85,7 @@ var Columns = []string{ FieldName, FieldGroupID, FieldStatus, + FieldLastUsedAt, FieldIPWhitelist, FieldIPBlacklist, FieldQuota, @@ -176,6 +179,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() } +// ByLastUsedAt orders the results by the last_used_at field. +func ByLastUsedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastUsedAt, opts...).ToFunc() +} + // ByQuota orders the results by the quota field. func ByQuota(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldQuota, opts...).ToFunc() diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index f54f44b7..c1900ee1 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -95,6 +95,11 @@ func Status(v string) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldStatus, v)) } +// LastUsedAt applies equality check predicate on the "last_used_at" field. It's identical to LastUsedAtEQ. +func LastUsedAt(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldLastUsedAt, v)) +} + // Quota applies equality check predicate on the "quota" field. It's identical to QuotaEQ. func Quota(v float64) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldQuota, v)) @@ -485,6 +490,56 @@ func StatusContainsFold(v string) predicate.APIKey { return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v)) } +// LastUsedAtEQ applies the EQ predicate on the "last_used_at" field. +func LastUsedAtEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldLastUsedAt, v)) +} + +// LastUsedAtNEQ applies the NEQ predicate on the "last_used_at" field. +func LastUsedAtNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldLastUsedAt, v)) +} + +// LastUsedAtIn applies the In predicate on the "last_used_at" field. +func LastUsedAtIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldLastUsedAt, vs...)) +} + +// LastUsedAtNotIn applies the NotIn predicate on the "last_used_at" field. +func LastUsedAtNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldLastUsedAt, vs...)) +} + +// LastUsedAtGT applies the GT predicate on the "last_used_at" field. +func LastUsedAtGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldLastUsedAt, v)) +} + +// LastUsedAtGTE applies the GTE predicate on the "last_used_at" field. +func LastUsedAtGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldLastUsedAt, v)) +} + +// LastUsedAtLT applies the LT predicate on the "last_used_at" field. +func LastUsedAtLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldLastUsedAt, v)) +} + +// LastUsedAtLTE applies the LTE predicate on the "last_used_at" field. +func LastUsedAtLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldLastUsedAt, v)) +} + +// LastUsedAtIsNil applies the IsNil predicate on the "last_used_at" field. +func LastUsedAtIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldLastUsedAt)) +} + +// LastUsedAtNotNil applies the NotNil predicate on the "last_used_at" field. +func LastUsedAtNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldLastUsedAt)) +} + // IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field. func IPWhitelistIsNil() predicate.APIKey { return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist)) diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index 71540975..bc506585 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -113,6 +113,20 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate { return _c } +// SetLastUsedAt sets the "last_used_at" field. +func (_c *APIKeyCreate) SetLastUsedAt(v time.Time) *APIKeyCreate { + _c.mutation.SetLastUsedAt(v) + return _c +} + +// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableLastUsedAt(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetLastUsedAt(*v) + } + return _c +} + // SetIPWhitelist sets the "ip_whitelist" field. func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate { _c.mutation.SetIPWhitelist(v) @@ -353,6 +367,10 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { _spec.SetField(apikey.FieldStatus, field.TypeString, value) _node.Status = value } + if value, ok := _c.mutation.LastUsedAt(); ok { + _spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value) + _node.LastUsedAt = &value + } if value, ok := _c.mutation.IPWhitelist(); ok { _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) _node.IPWhitelist = value @@ -571,6 +589,24 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert { return u } +// SetLastUsedAt sets the "last_used_at" field. +func (u *APIKeyUpsert) SetLastUsedAt(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldLastUsedAt, v) + return u +} + +// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateLastUsedAt() *APIKeyUpsert { + u.SetExcluded(apikey.FieldLastUsedAt) + return u +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (u *APIKeyUpsert) ClearLastUsedAt() *APIKeyUpsert { + u.SetNull(apikey.FieldLastUsedAt) + return u +} + // SetIPWhitelist sets the "ip_whitelist" field. func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert { u.Set(apikey.FieldIPWhitelist, v) @@ -818,6 +854,27 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne { }) } +// SetLastUsedAt sets the "last_used_at" field. +func (u *APIKeyUpsertOne) SetLastUsedAt(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetLastUsedAt(v) + }) +} + +// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateLastUsedAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateLastUsedAt() + }) +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (u *APIKeyUpsertOne) ClearLastUsedAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearLastUsedAt() + }) +} + // SetIPWhitelist sets the "ip_whitelist" field. func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne { return u.Update(func(s *APIKeyUpsert) { @@ -1246,6 +1303,27 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk { }) } +// SetLastUsedAt sets the "last_used_at" field. +func (u *APIKeyUpsertBulk) SetLastUsedAt(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetLastUsedAt(v) + }) +} + +// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateLastUsedAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateLastUsedAt() + }) +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (u *APIKeyUpsertBulk) ClearLastUsedAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearLastUsedAt() + }) +} + // SetIPWhitelist sets the "ip_whitelist" field. func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk { return u.Update(func(s *APIKeyUpsert) { diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index b4ff230b..6ca01854 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -134,6 +134,26 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate { return _u } +// SetLastUsedAt sets the "last_used_at" field. +func (_u *APIKeyUpdate) SetLastUsedAt(v time.Time) *APIKeyUpdate { + _u.mutation.SetLastUsedAt(v) + return _u +} + +// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableLastUsedAt(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetLastUsedAt(*v) + } + return _u +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (_u *APIKeyUpdate) ClearLastUsedAt() *APIKeyUpdate { + _u.mutation.ClearLastUsedAt() + return _u +} + // SetIPWhitelist sets the "ip_whitelist" field. func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate { _u.mutation.SetIPWhitelist(v) @@ -390,6 +410,12 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Status(); ok { _spec.SetField(apikey.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.LastUsedAt(); ok { + _spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value) + } + if _u.mutation.LastUsedAtCleared() { + _spec.ClearField(apikey.FieldLastUsedAt, field.TypeTime) + } if value, ok := _u.mutation.IPWhitelist(); ok { _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) } @@ -655,6 +681,26 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne { return _u } +// SetLastUsedAt sets the "last_used_at" field. +func (_u *APIKeyUpdateOne) SetLastUsedAt(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetLastUsedAt(v) + return _u +} + +// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableLastUsedAt(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetLastUsedAt(*v) + } + return _u +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (_u *APIKeyUpdateOne) ClearLastUsedAt() *APIKeyUpdateOne { + _u.mutation.ClearLastUsedAt() + return _u +} + // SetIPWhitelist sets the "ip_whitelist" field. func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne { _u.mutation.SetIPWhitelist(v) @@ -941,6 +987,12 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro if value, ok := _u.mutation.Status(); ok { _spec.SetField(apikey.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.LastUsedAt(); ok { + _spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value) + } + if _u.mutation.LastUsedAtCleared() { + _spec.ClearField(apikey.FieldLastUsedAt, field.TypeTime) + } if value, ok := _u.mutation.IPWhitelist(); ok { _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) } diff --git a/backend/ent/client.go b/backend/ent/client.go index a791c081..7ebbaa32 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -22,10 +22,12 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -57,6 +59,8 @@ type Client struct { ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient + // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. + IdempotencyRecord *IdempotencyRecordClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -65,6 +69,8 @@ type Client struct { Proxy *ProxyClient // RedeemCode is the client for interacting with the RedeemCode builders. RedeemCode *RedeemCodeClient + // SecuritySecret is the client for interacting with the SecuritySecret builders. + SecuritySecret *SecuritySecretClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. @@ -99,10 +105,12 @@ func (c *Client) init() { c.AnnouncementRead = NewAnnouncementReadClient(c.config) c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) + c.IdempotencyRecord = NewIdempotencyRecordClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) c.Proxy = NewProxyClient(c.config) c.RedeemCode = NewRedeemCodeClient(c.config) + c.SecuritySecret = NewSecuritySecretClient(c.config) c.Setting = NewSettingClient(c.config) c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config) c.UsageLog = NewUsageLogClient(c.config) @@ -210,10 +218,12 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { AnnouncementRead: NewAnnouncementReadClient(cfg), ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), Setting: NewSettingClient(cfg), UsageCleanupTask: NewUsageCleanupTaskClient(cfg), UsageLog: NewUsageLogClient(cfg), @@ -248,10 +258,12 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) AnnouncementRead: NewAnnouncementReadClient(cfg), ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), Setting: NewSettingClient(cfg), UsageCleanupTask: NewUsageCleanupTaskClient(cfg), UsageLog: NewUsageLogClient(cfg), @@ -290,10 +302,10 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, - c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, + c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) } @@ -304,10 +316,10 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, - c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, + c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) } @@ -330,6 +342,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) + case *IdempotencyRecordMutation: + return c.IdempotencyRecord.mutate(ctx, m) case *PromoCodeMutation: return c.PromoCode.mutate(ctx, m) case *PromoCodeUsageMutation: @@ -338,6 +352,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Proxy.mutate(ctx, m) case *RedeemCodeMutation: return c.RedeemCode.mutate(ctx, m) + case *SecuritySecretMutation: + return c.SecuritySecret.mutate(ctx, m) case *SettingMutation: return c.Setting.mutate(ctx, m) case *UsageCleanupTaskMutation: @@ -1567,6 +1583,139 @@ func (c *GroupClient) mutate(ctx context.Context, m *GroupMutation) (Value, erro } } +// IdempotencyRecordClient is a client for the IdempotencyRecord schema. +type IdempotencyRecordClient struct { + config +} + +// NewIdempotencyRecordClient returns a client for the IdempotencyRecord from the given config. +func NewIdempotencyRecordClient(c config) *IdempotencyRecordClient { + return &IdempotencyRecordClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `idempotencyrecord.Hooks(f(g(h())))`. +func (c *IdempotencyRecordClient) Use(hooks ...Hook) { + c.hooks.IdempotencyRecord = append(c.hooks.IdempotencyRecord, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `idempotencyrecord.Intercept(f(g(h())))`. +func (c *IdempotencyRecordClient) Intercept(interceptors ...Interceptor) { + c.inters.IdempotencyRecord = append(c.inters.IdempotencyRecord, interceptors...) +} + +// Create returns a builder for creating a IdempotencyRecord entity. +func (c *IdempotencyRecordClient) Create() *IdempotencyRecordCreate { + mutation := newIdempotencyRecordMutation(c.config, OpCreate) + return &IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of IdempotencyRecord entities. +func (c *IdempotencyRecordClient) CreateBulk(builders ...*IdempotencyRecordCreate) *IdempotencyRecordCreateBulk { + return &IdempotencyRecordCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *IdempotencyRecordClient) MapCreateBulk(slice any, setFunc func(*IdempotencyRecordCreate, int)) *IdempotencyRecordCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &IdempotencyRecordCreateBulk{err: fmt.Errorf("calling to IdempotencyRecordClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*IdempotencyRecordCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &IdempotencyRecordCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for IdempotencyRecord. +func (c *IdempotencyRecordClient) Update() *IdempotencyRecordUpdate { + mutation := newIdempotencyRecordMutation(c.config, OpUpdate) + return &IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *IdempotencyRecordClient) UpdateOne(_m *IdempotencyRecord) *IdempotencyRecordUpdateOne { + mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecord(_m)) + return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *IdempotencyRecordClient) UpdateOneID(id int64) *IdempotencyRecordUpdateOne { + mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecordID(id)) + return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for IdempotencyRecord. +func (c *IdempotencyRecordClient) Delete() *IdempotencyRecordDelete { + mutation := newIdempotencyRecordMutation(c.config, OpDelete) + return &IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *IdempotencyRecordClient) DeleteOne(_m *IdempotencyRecord) *IdempotencyRecordDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *IdempotencyRecordClient) DeleteOneID(id int64) *IdempotencyRecordDeleteOne { + builder := c.Delete().Where(idempotencyrecord.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &IdempotencyRecordDeleteOne{builder} +} + +// Query returns a query builder for IdempotencyRecord. +func (c *IdempotencyRecordClient) Query() *IdempotencyRecordQuery { + return &IdempotencyRecordQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeIdempotencyRecord}, + inters: c.Interceptors(), + } +} + +// Get returns a IdempotencyRecord entity by its id. +func (c *IdempotencyRecordClient) Get(ctx context.Context, id int64) (*IdempotencyRecord, error) { + return c.Query().Where(idempotencyrecord.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *IdempotencyRecordClient) GetX(ctx context.Context, id int64) *IdempotencyRecord { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *IdempotencyRecordClient) Hooks() []Hook { + return c.hooks.IdempotencyRecord +} + +// Interceptors returns the client interceptors. +func (c *IdempotencyRecordClient) Interceptors() []Interceptor { + return c.inters.IdempotencyRecord +} + +func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyRecordMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown IdempotencyRecord mutation op: %q", m.Op()) + } +} + // PromoCodeClient is a client for the PromoCode schema. type PromoCodeClient struct { config @@ -2197,6 +2346,139 @@ func (c *RedeemCodeClient) mutate(ctx context.Context, m *RedeemCodeMutation) (V } } +// SecuritySecretClient is a client for the SecuritySecret schema. +type SecuritySecretClient struct { + config +} + +// NewSecuritySecretClient returns a client for the SecuritySecret from the given config. +func NewSecuritySecretClient(c config) *SecuritySecretClient { + return &SecuritySecretClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `securitysecret.Hooks(f(g(h())))`. +func (c *SecuritySecretClient) Use(hooks ...Hook) { + c.hooks.SecuritySecret = append(c.hooks.SecuritySecret, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `securitysecret.Intercept(f(g(h())))`. +func (c *SecuritySecretClient) Intercept(interceptors ...Interceptor) { + c.inters.SecuritySecret = append(c.inters.SecuritySecret, interceptors...) +} + +// Create returns a builder for creating a SecuritySecret entity. +func (c *SecuritySecretClient) Create() *SecuritySecretCreate { + mutation := newSecuritySecretMutation(c.config, OpCreate) + return &SecuritySecretCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of SecuritySecret entities. +func (c *SecuritySecretClient) CreateBulk(builders ...*SecuritySecretCreate) *SecuritySecretCreateBulk { + return &SecuritySecretCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *SecuritySecretClient) MapCreateBulk(slice any, setFunc func(*SecuritySecretCreate, int)) *SecuritySecretCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &SecuritySecretCreateBulk{err: fmt.Errorf("calling to SecuritySecretClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*SecuritySecretCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &SecuritySecretCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for SecuritySecret. +func (c *SecuritySecretClient) Update() *SecuritySecretUpdate { + mutation := newSecuritySecretMutation(c.config, OpUpdate) + return &SecuritySecretUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *SecuritySecretClient) UpdateOne(_m *SecuritySecret) *SecuritySecretUpdateOne { + mutation := newSecuritySecretMutation(c.config, OpUpdateOne, withSecuritySecret(_m)) + return &SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *SecuritySecretClient) UpdateOneID(id int64) *SecuritySecretUpdateOne { + mutation := newSecuritySecretMutation(c.config, OpUpdateOne, withSecuritySecretID(id)) + return &SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for SecuritySecret. +func (c *SecuritySecretClient) Delete() *SecuritySecretDelete { + mutation := newSecuritySecretMutation(c.config, OpDelete) + return &SecuritySecretDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *SecuritySecretClient) DeleteOne(_m *SecuritySecret) *SecuritySecretDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *SecuritySecretClient) DeleteOneID(id int64) *SecuritySecretDeleteOne { + builder := c.Delete().Where(securitysecret.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &SecuritySecretDeleteOne{builder} +} + +// Query returns a query builder for SecuritySecret. +func (c *SecuritySecretClient) Query() *SecuritySecretQuery { + return &SecuritySecretQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeSecuritySecret}, + inters: c.Interceptors(), + } +} + +// Get returns a SecuritySecret entity by its id. +func (c *SecuritySecretClient) Get(ctx context.Context, id int64) (*SecuritySecret, error) { + return c.Query().Where(securitysecret.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *SecuritySecretClient) GetX(ctx context.Context, id int64) *SecuritySecret { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *SecuritySecretClient) Hooks() []Hook { + return c.hooks.SecuritySecret +} + +// Interceptors returns the client interceptors. +func (c *SecuritySecretClient) Interceptors() []Interceptor { + return c.inters.SecuritySecret +} + +func (c *SecuritySecretClient) mutate(ctx context.Context, m *SecuritySecretMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&SecuritySecretCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&SecuritySecretUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&SecuritySecretDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown SecuritySecret mutation op: %q", m.Op()) + } +} + // SettingClient is a client for the Setting schema. type SettingClient struct { config @@ -3606,15 +3888,17 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription type ( hooks struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook + ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage, + Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, + UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, + UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor + ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage, + Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, + UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, + UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 5767a167..5197e4d8 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -19,10 +19,12 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -98,10 +100,12 @@ func checkColumn(t, c string) error { announcementread.Table: announcementread.ValidColumn, errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, group.Table: group.ValidColumn, + idempotencyrecord.Table: idempotencyrecord.ValidColumn, promocode.Table: promocode.ValidColumn, promocodeusage.Table: promocodeusage.ValidColumn, proxy.Table: proxy.ValidColumn, redeemcode.Table: redeemcode.ValidColumn, + securitysecret.Table: securitysecret.ValidColumn, setting.Table: setting.ValidColumn, usagecleanuptask.Table: usagecleanuptask.ValidColumn, usagelog.Table: usagelog.ValidColumn, diff --git a/backend/ent/errorpassthroughrule.go b/backend/ent/errorpassthroughrule.go index 1932f626..62468719 100644 --- a/backend/ent/errorpassthroughrule.go +++ b/backend/ent/errorpassthroughrule.go @@ -44,6 +44,8 @@ type ErrorPassthroughRule struct { PassthroughBody bool `json:"passthrough_body,omitempty"` // CustomMessage holds the value of the "custom_message" field. CustomMessage *string `json:"custom_message,omitempty"` + // SkipMonitoring holds the value of the "skip_monitoring" field. + SkipMonitoring bool `json:"skip_monitoring,omitempty"` // Description holds the value of the "description" field. Description *string `json:"description,omitempty"` selectValues sql.SelectValues @@ -56,7 +58,7 @@ func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) { switch columns[i] { case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms: values[i] = new([]byte) - case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody: + case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody, errorpassthroughrule.FieldSkipMonitoring: values[i] = new(sql.NullBool) case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode: values[i] = new(sql.NullInt64) @@ -171,6 +173,12 @@ func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) err _m.CustomMessage = new(string) *_m.CustomMessage = value.String } + case errorpassthroughrule.FieldSkipMonitoring: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field skip_monitoring", values[i]) + } else if value.Valid { + _m.SkipMonitoring = value.Bool + } case errorpassthroughrule.FieldDescription: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field description", values[i]) @@ -257,6 +265,9 @@ func (_m *ErrorPassthroughRule) String() string { builder.WriteString(*v) } builder.WriteString(", ") + builder.WriteString("skip_monitoring=") + builder.WriteString(fmt.Sprintf("%v", _m.SkipMonitoring)) + builder.WriteString(", ") if v := _m.Description; v != nil { builder.WriteString("description=") builder.WriteString(*v) diff --git a/backend/ent/errorpassthroughrule/errorpassthroughrule.go b/backend/ent/errorpassthroughrule/errorpassthroughrule.go index d7be4f03..859fc761 100644 --- a/backend/ent/errorpassthroughrule/errorpassthroughrule.go +++ b/backend/ent/errorpassthroughrule/errorpassthroughrule.go @@ -39,6 +39,8 @@ const ( FieldPassthroughBody = "passthrough_body" // FieldCustomMessage holds the string denoting the custom_message field in the database. FieldCustomMessage = "custom_message" + // FieldSkipMonitoring holds the string denoting the skip_monitoring field in the database. + FieldSkipMonitoring = "skip_monitoring" // FieldDescription holds the string denoting the description field in the database. FieldDescription = "description" // Table holds the table name of the errorpassthroughrule in the database. @@ -61,6 +63,7 @@ var Columns = []string{ FieldResponseCode, FieldPassthroughBody, FieldCustomMessage, + FieldSkipMonitoring, FieldDescription, } @@ -95,6 +98,8 @@ var ( DefaultPassthroughCode bool // DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field. DefaultPassthroughBody bool + // DefaultSkipMonitoring holds the default value on creation for the "skip_monitoring" field. + DefaultSkipMonitoring bool ) // OrderOption defines the ordering options for the ErrorPassthroughRule queries. @@ -155,6 +160,11 @@ func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCustomMessage, opts...).ToFunc() } +// BySkipMonitoring orders the results by the skip_monitoring field. +func BySkipMonitoring(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSkipMonitoring, opts...).ToFunc() +} + // ByDescription orders the results by the description field. func ByDescription(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDescription, opts...).ToFunc() diff --git a/backend/ent/errorpassthroughrule/where.go b/backend/ent/errorpassthroughrule/where.go index 56839d52..87654678 100644 --- a/backend/ent/errorpassthroughrule/where.go +++ b/backend/ent/errorpassthroughrule/where.go @@ -104,6 +104,11 @@ func CustomMessage(v string) predicate.ErrorPassthroughRule { return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v)) } +// SkipMonitoring applies equality check predicate on the "skip_monitoring" field. It's identical to SkipMonitoringEQ. +func SkipMonitoring(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v)) +} + // Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ. func Description(v string) predicate.ErrorPassthroughRule { return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) @@ -544,6 +549,16 @@ func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule { return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v)) } +// SkipMonitoringEQ applies the EQ predicate on the "skip_monitoring" field. +func SkipMonitoringEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v)) +} + +// SkipMonitoringNEQ applies the NEQ predicate on the "skip_monitoring" field. +func SkipMonitoringNEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldSkipMonitoring, v)) +} + // DescriptionEQ applies the EQ predicate on the "description" field. func DescriptionEQ(v string) predicate.ErrorPassthroughRule { return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) diff --git a/backend/ent/errorpassthroughrule_create.go b/backend/ent/errorpassthroughrule_create.go index 4dc08dce..8173936b 100644 --- a/backend/ent/errorpassthroughrule_create.go +++ b/backend/ent/errorpassthroughrule_create.go @@ -172,6 +172,20 @@ func (_c *ErrorPassthroughRuleCreate) SetNillableCustomMessage(v *string) *Error return _c } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (_c *ErrorPassthroughRuleCreate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleCreate { + _c.mutation.SetSkipMonitoring(v) + return _c +} + +// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetSkipMonitoring(*v) + } + return _c +} + // SetDescription sets the "description" field. func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate { _c.mutation.SetDescription(v) @@ -249,6 +263,10 @@ func (_c *ErrorPassthroughRuleCreate) defaults() { v := errorpassthroughrule.DefaultPassthroughBody _c.mutation.SetPassthroughBody(v) } + if _, ok := _c.mutation.SkipMonitoring(); !ok { + v := errorpassthroughrule.DefaultSkipMonitoring + _c.mutation.SetSkipMonitoring(v) + } } // check runs all checks and user-defined validators on the builder. @@ -287,6 +305,9 @@ func (_c *ErrorPassthroughRuleCreate) check() error { if _, ok := _c.mutation.PassthroughBody(); !ok { return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)} } + if _, ok := _c.mutation.SkipMonitoring(); !ok { + return &ValidationError{Name: "skip_monitoring", err: errors.New(`ent: missing required field "ErrorPassthroughRule.skip_monitoring"`)} + } return nil } @@ -366,6 +387,10 @@ func (_c *ErrorPassthroughRuleCreate) createSpec() (*ErrorPassthroughRule, *sqlg _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) _node.CustomMessage = &value } + if value, ok := _c.mutation.SkipMonitoring(); ok { + _spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value) + _node.SkipMonitoring = value + } if value, ok := _c.mutation.Description(); ok { _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) _node.Description = &value @@ -608,6 +633,18 @@ func (u *ErrorPassthroughRuleUpsert) ClearCustomMessage() *ErrorPassthroughRuleU return u } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (u *ErrorPassthroughRuleUpsert) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldSkipMonitoring, v) + return u +} + +// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldSkipMonitoring) + return u +} + // SetDescription sets the "description" field. func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert { u.Set(errorpassthroughrule.FieldDescription, v) @@ -888,6 +925,20 @@ func (u *ErrorPassthroughRuleUpsertOne) ClearCustomMessage() *ErrorPassthroughRu }) } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (u *ErrorPassthroughRuleUpsertOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetSkipMonitoring(v) + }) +} + +// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateSkipMonitoring() + }) +} + // SetDescription sets the "description" field. func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne { return u.Update(func(s *ErrorPassthroughRuleUpsert) { @@ -1337,6 +1388,20 @@ func (u *ErrorPassthroughRuleUpsertBulk) ClearCustomMessage() *ErrorPassthroughR }) } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetSkipMonitoring(v) + }) +} + +// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateSkipMonitoring() + }) +} + // SetDescription sets the "description" field. func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk { return u.Update(func(s *ErrorPassthroughRuleUpsert) { diff --git a/backend/ent/errorpassthroughrule_update.go b/backend/ent/errorpassthroughrule_update.go index 9d52aa49..7e42d9fc 100644 --- a/backend/ent/errorpassthroughrule_update.go +++ b/backend/ent/errorpassthroughrule_update.go @@ -227,6 +227,20 @@ func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRule return _u } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (_u *ErrorPassthroughRuleUpdate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdate { + _u.mutation.SetSkipMonitoring(v) + return _u +} + +// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetSkipMonitoring(*v) + } + return _u +} + // SetDescription sets the "description" field. func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate { _u.mutation.SetDescription(v) @@ -387,6 +401,9 @@ func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, e if _u.mutation.CustomMessageCleared() { _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) } + if value, ok := _u.mutation.SkipMonitoring(); ok { + _spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value) + } if value, ok := _u.mutation.Description(); ok { _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) } @@ -611,6 +628,20 @@ func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughR return _u } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetSkipMonitoring(v) + return _u +} + +// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetSkipMonitoring(*v) + } + return _u +} + // SetDescription sets the "description" field. func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne { _u.mutation.SetDescription(v) @@ -801,6 +832,9 @@ func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *Er if _u.mutation.CustomMessageCleared() { _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) } + if value, ok := _u.mutation.SkipMonitoring(); ok { + _spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value) + } if value, ok := _u.mutation.Description(); ok { _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) } diff --git a/backend/ent/group.go b/backend/ent/group.go index 1eb05e0e..76c3cae2 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -52,6 +52,16 @@ type Group struct { ImagePrice2k *float64 `json:"image_price_2k,omitempty"` // ImagePrice4k holds the value of the "image_price_4k" field. ImagePrice4k *float64 `json:"image_price_4k,omitempty"` + // SoraImagePrice360 holds the value of the "sora_image_price_360" field. + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + // SoraImagePrice540 holds the value of the "sora_image_price_540" field. + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + // SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field. + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + // SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field. + SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"` + // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` // 是否仅允许 Claude Code 客户端 ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID @@ -66,6 +76,8 @@ type Group struct { McpXMLInject bool `json:"mcp_xml_inject,omitempty"` // 支持的模型系列:claude, gemini_text, gemini_image SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` + // 分组显示排序,数值越小越靠前 + SortOrder int `json:"sort_order,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -176,9 +188,9 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject: values[i] = new(sql.NullBool) - case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: + case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -315,6 +327,40 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.ImagePrice4k = new(float64) *_m.ImagePrice4k = value.Float64 } + case group.FieldSoraImagePrice360: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i]) + } else if value.Valid { + _m.SoraImagePrice360 = new(float64) + *_m.SoraImagePrice360 = value.Float64 + } + case group.FieldSoraImagePrice540: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i]) + } else if value.Valid { + _m.SoraImagePrice540 = new(float64) + *_m.SoraImagePrice540 = value.Float64 + } + case group.FieldSoraVideoPricePerRequest: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i]) + } else if value.Valid { + _m.SoraVideoPricePerRequest = new(float64) + *_m.SoraVideoPricePerRequest = value.Float64 + } + case group.FieldSoraVideoPricePerRequestHd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i]) + } else if value.Valid { + _m.SoraVideoPricePerRequestHd = new(float64) + *_m.SoraVideoPricePerRequestHd = value.Float64 + } + case group.FieldSoraStorageQuotaBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageQuotaBytes = value.Int64 + } case group.FieldClaudeCodeOnly: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) @@ -363,6 +409,12 @@ func (_m *Group) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field supported_model_scopes: %w", err) } } + case group.FieldSortOrder: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sort_order", values[i]) + } else if value.Valid { + _m.SortOrder = int(value.Int64) + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -506,6 +558,29 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + if v := _m.SoraImagePrice360; v != nil { + builder.WriteString("sora_image_price_360=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraImagePrice540; v != nil { + builder.WriteString("sora_image_price_540=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraVideoPricePerRequest; v != nil { + builder.WriteString("sora_video_price_per_request=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraVideoPricePerRequestHd; v != nil { + builder.WriteString("sora_video_price_per_request_hd=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("sora_storage_quota_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) + builder.WriteString(", ") builder.WriteString("claude_code_only=") builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) builder.WriteString(", ") @@ -530,6 +605,9 @@ func (_m *Group) String() string { builder.WriteString(", ") builder.WriteString("supported_model_scopes=") builder.WriteString(fmt.Sprintf("%v", _m.SupportedModelScopes)) + builder.WriteString(", ") + builder.WriteString("sort_order=") + builder.WriteString(fmt.Sprintf("%v", _m.SortOrder)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 278b2daf..6ac4eea1 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -49,6 +49,16 @@ const ( FieldImagePrice2k = "image_price_2k" // FieldImagePrice4k holds the string denoting the image_price_4k field in the database. FieldImagePrice4k = "image_price_4k" + // FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database. + FieldSoraImagePrice360 = "sora_image_price_360" + // FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database. + FieldSoraImagePrice540 = "sora_image_price_540" + // FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database. + FieldSoraVideoPricePerRequest = "sora_video_price_per_request" + // FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database. + FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd" + // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. + FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. @@ -63,6 +73,8 @@ const ( FieldMcpXMLInject = "mcp_xml_inject" // FieldSupportedModelScopes holds the string denoting the supported_model_scopes field in the database. FieldSupportedModelScopes = "supported_model_scopes" + // FieldSortOrder holds the string denoting the sort_order field in the database. + FieldSortOrder = "sort_order" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -155,6 +167,11 @@ var Columns = []string{ FieldImagePrice1k, FieldImagePrice2k, FieldImagePrice4k, + FieldSoraImagePrice360, + FieldSoraImagePrice540, + FieldSoraVideoPricePerRequest, + FieldSoraVideoPricePerRequestHd, + FieldSoraStorageQuotaBytes, FieldClaudeCodeOnly, FieldFallbackGroupID, FieldFallbackGroupIDOnInvalidRequest, @@ -162,6 +179,7 @@ var Columns = []string{ FieldModelRoutingEnabled, FieldMcpXMLInject, FieldSupportedModelScopes, + FieldSortOrder, } var ( @@ -217,6 +235,8 @@ var ( SubscriptionTypeValidator func(string) error // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. DefaultDefaultValidityDays int + // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. + DefaultSoraStorageQuotaBytes int64 // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. DefaultClaudeCodeOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. @@ -225,6 +245,8 @@ var ( DefaultMcpXMLInject bool // DefaultSupportedModelScopes holds the default value on creation for the "supported_model_scopes" field. DefaultSupportedModelScopes []string + // DefaultSortOrder holds the default value on creation for the "sort_order" field. + DefaultSortOrder int ) // OrderOption defines the ordering options for the Group queries. @@ -320,6 +342,31 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc() } +// BySoraImagePrice360 orders the results by the sora_image_price_360 field. +func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc() +} + +// BySoraImagePrice540 orders the results by the sora_image_price_540 field. +func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc() +} + +// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field. +func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc() +} + +// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field. +func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc() +} + +// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. +func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() +} + // ByClaudeCodeOnly orders the results by the claude_code_only field. func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() @@ -345,6 +392,11 @@ func ByMcpXMLInject(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldMcpXMLInject, opts...).ToFunc() } +// BySortOrder orders the results by the sort_order field. +func BySortOrder(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSortOrder, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index b6fa2c33..4cf65d0f 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -140,6 +140,31 @@ func ImagePrice4k(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v)) } +// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ. +func SoraImagePrice360(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ. +func SoraImagePrice540(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) +} + +// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ. +func SoraVideoPricePerRequest(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ. +func SoraVideoPricePerRequestHd(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. +func SoraStorageQuotaBytes(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + // ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. func ClaudeCodeOnly(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) @@ -165,6 +190,11 @@ func McpXMLInject(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v)) } +// SortOrder applies equality check predicate on the "sort_order" field. It's identical to SortOrderEQ. +func SortOrder(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSortOrder, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1020,6 +1050,246 @@ func ImagePrice4kNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldImagePrice4k)) } +// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field. +func SoraImagePrice360EQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field. +func SoraImagePrice360NEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field. +func SoraImagePrice360In(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...)) +} + +// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field. +func SoraImagePrice360NotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...)) +} + +// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field. +func SoraImagePrice360GT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field. +func SoraImagePrice360GTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field. +func SoraImagePrice360LT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field. +func SoraImagePrice360LTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field. +func SoraImagePrice360IsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360)) +} + +// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field. +func SoraImagePrice360NotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360)) +} + +// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field. +func SoraImagePrice540EQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field. +func SoraImagePrice540NEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field. +func SoraImagePrice540In(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...)) +} + +// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field. +func SoraImagePrice540NotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...)) +} + +// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field. +func SoraImagePrice540GT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field. +func SoraImagePrice540GTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field. +func SoraImagePrice540LT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field. +func SoraImagePrice540LTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field. +func SoraImagePrice540IsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540)) +} + +// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field. +func SoraImagePrice540NotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540)) +} + +// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...)) +} + +// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...)) +} + +// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest)) +} + +// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest)) +} + +// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...)) +} + +// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...)) +} + +// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd)) +} + +// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd)) +} + +// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGT(v int64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGTE(v int64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLT(v int64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLTE(v int64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) +} + // ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. func ClaudeCodeOnlyEQ(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) @@ -1160,6 +1430,46 @@ func McpXMLInjectNEQ(v bool) predicate.Group { return predicate.Group(sql.FieldNEQ(FieldMcpXMLInject, v)) } +// SortOrderEQ applies the EQ predicate on the "sort_order" field. +func SortOrderEQ(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSortOrder, v)) +} + +// SortOrderNEQ applies the NEQ predicate on the "sort_order" field. +func SortOrderNEQ(v int) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSortOrder, v)) +} + +// SortOrderIn applies the In predicate on the "sort_order" field. +func SortOrderIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSortOrder, vs...)) +} + +// SortOrderNotIn applies the NotIn predicate on the "sort_order" field. +func SortOrderNotIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSortOrder, vs...)) +} + +// SortOrderGT applies the GT predicate on the "sort_order" field. +func SortOrderGT(v int) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSortOrder, v)) +} + +// SortOrderGTE applies the GTE predicate on the "sort_order" field. +func SortOrderGTE(v int) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSortOrder, v)) +} + +// SortOrderLT applies the LT predicate on the "sort_order" field. +func SortOrderLT(v int) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSortOrder, v)) +} + +// SortOrderLTE applies the LTE predicate on the "sort_order" field. +func SortOrderLTE(v int) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSortOrder, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 9d845b61..0ce5f959 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -258,6 +258,76 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate { return _c } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate { + _c.mutation.SetSoraImagePrice360(v) + return _c +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraImagePrice360(*v) + } + return _c +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate { + _c.mutation.SetSoraImagePrice540(v) + return _c +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraImagePrice540(*v) + } + return _c +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate { + _c.mutation.SetSoraVideoPricePerRequest(v) + return _c +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraVideoPricePerRequest(*v) + } + return _c +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate { + _c.mutation.SetSoraVideoPricePerRequestHd(v) + return _c +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraVideoPricePerRequestHd(*v) + } + return _c +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate { + _c.mutation.SetSoraStorageQuotaBytes(v) + return _c +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate { + if v != nil { + _c.SetSoraStorageQuotaBytes(*v) + } + return _c +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { _c.mutation.SetClaudeCodeOnly(v) @@ -340,6 +410,20 @@ func (_c *GroupCreate) SetSupportedModelScopes(v []string) *GroupCreate { return _c } +// SetSortOrder sets the "sort_order" field. +func (_c *GroupCreate) SetSortOrder(v int) *GroupCreate { + _c.mutation.SetSortOrder(v) + return _c +} + +// SetNillableSortOrder sets the "sort_order" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate { + if v != nil { + _c.SetSortOrder(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -505,6 +589,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultValidityDays _c.mutation.SetDefaultValidityDays(v) } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + v := group.DefaultSoraStorageQuotaBytes + _c.mutation.SetSoraStorageQuotaBytes(v) + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { v := group.DefaultClaudeCodeOnly _c.mutation.SetClaudeCodeOnly(v) @@ -521,6 +609,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultSupportedModelScopes _c.mutation.SetSupportedModelScopes(v) } + if _, ok := _c.mutation.SortOrder(); !ok { + v := group.DefaultSortOrder + _c.mutation.SetSortOrder(v) + } return nil } @@ -573,6 +665,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.DefaultValidityDays(); !ok { return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)} + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} } @@ -585,6 +680,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.SupportedModelScopes(); !ok { return &ValidationError{Name: "supported_model_scopes", err: errors.New(`ent: missing required field "Group.supported_model_scopes"`)} } + if _, ok := _c.mutation.SortOrder(); !ok { + return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)} + } return nil } @@ -680,6 +778,26 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value) _node.ImagePrice4k = &value } + if value, ok := _c.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + _node.SoraImagePrice360 = &value + } + if value, ok := _c.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + _node.SoraImagePrice540 = &value + } + if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + _node.SoraVideoPricePerRequest = &value + } + if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + _node.SoraVideoPricePerRequestHd = &value + } + if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + _node.SoraStorageQuotaBytes = value + } if value, ok := _c.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _node.ClaudeCodeOnly = value @@ -708,6 +826,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value) _node.SupportedModelScopes = value } + if value, ok := _c.mutation.SortOrder(); ok { + _spec.SetField(group.FieldSortOrder, field.TypeInt, value) + _node.SortOrder = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1152,6 +1274,120 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert { return u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert { + u.Set(group.FieldSoraImagePrice360, v) + return u +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert { + u.SetExcluded(group.FieldSoraImagePrice360) + return u +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert { + u.Add(group.FieldSoraImagePrice360, v) + return u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert { + u.SetNull(group.FieldSoraImagePrice360) + return u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert { + u.Set(group.FieldSoraImagePrice540, v) + return u +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert { + u.SetExcluded(group.FieldSoraImagePrice540) + return u +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. +func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert { + u.Add(group.FieldSoraImagePrice540, v) + return u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert { + u.SetNull(group.FieldSoraImagePrice540) + return u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert { + u.Set(group.FieldSoraVideoPricePerRequest, v) + return u +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert { + u.SetExcluded(group.FieldSoraVideoPricePerRequest) + return u +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert { + u.Add(group.FieldSoraVideoPricePerRequest, v) + return u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert { + u.SetNull(group.FieldSoraVideoPricePerRequest) + return u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert { + u.Set(group.FieldSoraVideoPricePerRequestHd, v) + return u +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert { + u.SetExcluded(group.FieldSoraVideoPricePerRequestHd) + return u +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert { + u.Add(group.FieldSoraVideoPricePerRequestHd, v) + return u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert { + u.SetNull(group.FieldSoraVideoPricePerRequestHd) + return u +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert { + u.Set(group.FieldSoraStorageQuotaBytes, v) + return u +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert { + u.SetExcluded(group.FieldSoraStorageQuotaBytes) + return u +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert { + u.Add(group.FieldSoraStorageQuotaBytes, v) + return u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { u.Set(group.FieldClaudeCodeOnly, v) @@ -1266,6 +1502,24 @@ func (u *GroupUpsert) UpdateSupportedModelScopes() *GroupUpsert { return u } +// SetSortOrder sets the "sort_order" field. +func (u *GroupUpsert) SetSortOrder(v int) *GroupUpsert { + u.Set(group.FieldSortOrder, v) + return u +} + +// UpdateSortOrder sets the "sort_order" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSortOrder() *GroupUpsert { + u.SetExcluded(group.FieldSortOrder) + return u +} + +// AddSortOrder adds v to the "sort_order" field. +func (u *GroupUpsert) AddSortOrder(v int) *GroupUpsert { + u.Add(group.FieldSortOrder, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1647,6 +1901,139 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne { }) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice360(v) + }) +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice360(v) + }) +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice360() + }) +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice360() + }) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice540(v) + }) +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. +func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice540(v) + }) +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice540() + }) +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice540() + }) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequest(v) + }) +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequest(v) + }) +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequest() + }) +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequest() + }) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequestHd(v) + }) +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequestHd(v) + }) +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequestHd() + }) +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequestHd() + }) +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -1780,6 +2167,27 @@ func (u *GroupUpsertOne) UpdateSupportedModelScopes() *GroupUpsertOne { }) } +// SetSortOrder sets the "sort_order" field. +func (u *GroupUpsertOne) SetSortOrder(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSortOrder(v) + }) +} + +// AddSortOrder adds v to the "sort_order" field. +func (u *GroupUpsertOne) AddSortOrder(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSortOrder(v) + }) +} + +// UpdateSortOrder sets the "sort_order" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSortOrder() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2327,6 +2735,139 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk { }) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice360(v) + }) +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice360(v) + }) +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice360() + }) +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice360() + }) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice540(v) + }) +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. +func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice540(v) + }) +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice540() + }) +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice540() + }) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequest(v) + }) +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequest(v) + }) +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequest() + }) +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequest() + }) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequestHd(v) + }) +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequestHd(v) + }) +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequestHd() + }) +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequestHd() + }) +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { @@ -2460,6 +3001,27 @@ func (u *GroupUpsertBulk) UpdateSupportedModelScopes() *GroupUpsertBulk { }) } +// SetSortOrder sets the "sort_order" field. +func (u *GroupUpsertBulk) SetSortOrder(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSortOrder(v) + }) +} + +// AddSortOrder adds v to the "sort_order" field. +func (u *GroupUpsertBulk) AddSortOrder(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSortOrder(v) + }) +} + +// UpdateSortOrder sets the "sort_order" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSortOrder() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 9e7246ea..85575292 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -355,6 +355,135 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate { return _u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate { + _u.mutation.ResetSoraImagePrice360() + _u.mutation.SetSoraImagePrice360(v) + return _u +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraImagePrice360(*v) + } + return _u +} + +// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. +func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate { + _u.mutation.AddSoraImagePrice360(v) + return _u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate { + _u.mutation.ClearSoraImagePrice360() + return _u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate { + _u.mutation.ResetSoraImagePrice540() + _u.mutation.SetSoraImagePrice540(v) + return _u +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraImagePrice540(*v) + } + return _u +} + +// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. +func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate { + _u.mutation.AddSoraImagePrice540(v) + return _u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate { + _u.mutation.ClearSoraImagePrice540() + return _u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate { + _u.mutation.ResetSoraVideoPricePerRequest() + _u.mutation.SetSoraVideoPricePerRequest(v) + return _u +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraVideoPricePerRequest(*v) + } + return _u +} + +// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. +func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate { + _u.mutation.AddSoraVideoPricePerRequest(v) + return _u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate { + _u.mutation.ClearSoraVideoPricePerRequest() + return _u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate { + _u.mutation.ResetSoraVideoPricePerRequestHd() + _u.mutation.SetSoraVideoPricePerRequestHd(v) + return _u +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraVideoPricePerRequestHd(*v) + } + return _u +} + +// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate { + _u.mutation.AddSoraVideoPricePerRequestHd(v) + return _u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate { + _u.mutation.ClearSoraVideoPricePerRequestHd() + return _u +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { _u.mutation.SetClaudeCodeOnly(v) @@ -475,6 +604,27 @@ func (_u *GroupUpdate) AppendSupportedModelScopes(v []string) *GroupUpdate { return _u } +// SetSortOrder sets the "sort_order" field. +func (_u *GroupUpdate) SetSortOrder(v int) *GroupUpdate { + _u.mutation.ResetSortOrder() + _u.mutation.SetSortOrder(v) + return _u +} + +// SetNillableSortOrder sets the "sort_order" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSortOrder(v *int) *GroupUpdate { + if v != nil { + _u.SetSortOrder(*v) + } + return _u +} + +// AddSortOrder adds value to the "sort_order" field. +func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate { + _u.mutation.AddSortOrder(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -871,6 +1021,48 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { + _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice360Cleared() { + _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { + _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice540Cleared() { + _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestHdCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } @@ -912,6 +1104,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { sqljson.Append(u, group.FieldSupportedModelScopes, value) }) } + if value, ok := _u.mutation.SortOrder(); ok { + _spec.SetField(group.FieldSortOrder, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedSortOrder(); ok { + _spec.AddField(group.FieldSortOrder, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1546,6 +1744,135 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne { return _u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraImagePrice360() + _u.mutation.SetSoraImagePrice360(v) + return _u +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraImagePrice360(*v) + } + return _u +} + +// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. +func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne { + _u.mutation.AddSoraImagePrice360(v) + return _u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne { + _u.mutation.ClearSoraImagePrice360() + return _u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraImagePrice540() + _u.mutation.SetSoraImagePrice540(v) + return _u +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraImagePrice540(*v) + } + return _u +} + +// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. +func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne { + _u.mutation.AddSoraImagePrice540(v) + return _u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne { + _u.mutation.ClearSoraImagePrice540() + return _u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraVideoPricePerRequest() + _u.mutation.SetSoraVideoPricePerRequest(v) + return _u +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraVideoPricePerRequest(*v) + } + return _u +} + +// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne { + _u.mutation.AddSoraVideoPricePerRequest(v) + return _u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne { + _u.mutation.ClearSoraVideoPricePerRequest() + return _u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraVideoPricePerRequestHd() + _u.mutation.SetSoraVideoPricePerRequestHd(v) + return _u +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraVideoPricePerRequestHd(*v) + } + return _u +} + +// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { + _u.mutation.AddSoraVideoPricePerRequestHd(v) + return _u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne { + _u.mutation.ClearSoraVideoPricePerRequestHd() + return _u +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { _u.mutation.SetClaudeCodeOnly(v) @@ -1666,6 +1993,27 @@ func (_u *GroupUpdateOne) AppendSupportedModelScopes(v []string) *GroupUpdateOne return _u } +// SetSortOrder sets the "sort_order" field. +func (_u *GroupUpdateOne) SetSortOrder(v int) *GroupUpdateOne { + _u.mutation.ResetSortOrder() + _u.mutation.SetSortOrder(v) + return _u +} + +// SetNillableSortOrder sets the "sort_order" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSortOrder(v *int) *GroupUpdateOne { + if v != nil { + _u.SetSortOrder(*v) + } + return _u +} + +// AddSortOrder adds value to the "sort_order" field. +func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne { + _u.mutation.AddSortOrder(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -2092,6 +2440,48 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { + _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice360Cleared() { + _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { + _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice540Cleared() { + _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestHdCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } @@ -2133,6 +2523,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) sqljson.Append(u, group.FieldSupportedModelScopes, value) }) } + if value, ok := _u.mutation.SortOrder(); ok { + _spec.SetField(group.FieldSortOrder, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedSortOrder(); ok { + _spec.AddField(group.FieldSortOrder, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 1b15685c..49d7f3c5 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -93,6 +93,18 @@ func (f GroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.GroupMutation", m) } +// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary +// function as IdempotencyRecord mutator. +type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.IdempotencyRecordMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary // function as PromoCode mutator. type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error) @@ -141,6 +153,18 @@ func (f RedeemCodeFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.RedeemCodeMutation", m) } +// The SecuritySecretFunc type is an adapter to allow the use of ordinary +// function as SecuritySecret mutator. +type SecuritySecretFunc func(context.Context, *ent.SecuritySecretMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f SecuritySecretFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.SecuritySecretMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SecuritySecretMutation", m) +} + // The SettingFunc type is an adapter to allow the use of ordinary // function as Setting mutator. type SettingFunc func(context.Context, *ent.SettingMutation) (ent.Value, error) diff --git a/backend/ent/idempotencyrecord.go b/backend/ent/idempotencyrecord.go new file mode 100644 index 00000000..ab120f8f --- /dev/null +++ b/backend/ent/idempotencyrecord.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" +) + +// IdempotencyRecord is the model entity for the IdempotencyRecord schema. +type IdempotencyRecord struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Scope holds the value of the "scope" field. + Scope string `json:"scope,omitempty"` + // IdempotencyKeyHash holds the value of the "idempotency_key_hash" field. + IdempotencyKeyHash string `json:"idempotency_key_hash,omitempty"` + // RequestFingerprint holds the value of the "request_fingerprint" field. + RequestFingerprint string `json:"request_fingerprint,omitempty"` + // Status holds the value of the "status" field. + Status string `json:"status,omitempty"` + // ResponseStatus holds the value of the "response_status" field. + ResponseStatus *int `json:"response_status,omitempty"` + // ResponseBody holds the value of the "response_body" field. + ResponseBody *string `json:"response_body,omitempty"` + // ErrorReason holds the value of the "error_reason" field. + ErrorReason *string `json:"error_reason,omitempty"` + // LockedUntil holds the value of the "locked_until" field. + LockedUntil *time.Time `json:"locked_until,omitempty"` + // ExpiresAt holds the value of the "expires_at" field. + ExpiresAt time.Time `json:"expires_at,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*IdempotencyRecord) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case idempotencyrecord.FieldID, idempotencyrecord.FieldResponseStatus: + values[i] = new(sql.NullInt64) + case idempotencyrecord.FieldScope, idempotencyrecord.FieldIdempotencyKeyHash, idempotencyrecord.FieldRequestFingerprint, idempotencyrecord.FieldStatus, idempotencyrecord.FieldResponseBody, idempotencyrecord.FieldErrorReason: + values[i] = new(sql.NullString) + case idempotencyrecord.FieldCreatedAt, idempotencyrecord.FieldUpdatedAt, idempotencyrecord.FieldLockedUntil, idempotencyrecord.FieldExpiresAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the IdempotencyRecord fields. +func (_m *IdempotencyRecord) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case idempotencyrecord.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case idempotencyrecord.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case idempotencyrecord.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case idempotencyrecord.FieldScope: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field scope", values[i]) + } else if value.Valid { + _m.Scope = value.String + } + case idempotencyrecord.FieldIdempotencyKeyHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field idempotency_key_hash", values[i]) + } else if value.Valid { + _m.IdempotencyKeyHash = value.String + } + case idempotencyrecord.FieldRequestFingerprint: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field request_fingerprint", values[i]) + } else if value.Valid { + _m.RequestFingerprint = value.String + } + case idempotencyrecord.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + _m.Status = value.String + } + case idempotencyrecord.FieldResponseStatus: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field response_status", values[i]) + } else if value.Valid { + _m.ResponseStatus = new(int) + *_m.ResponseStatus = int(value.Int64) + } + case idempotencyrecord.FieldResponseBody: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field response_body", values[i]) + } else if value.Valid { + _m.ResponseBody = new(string) + *_m.ResponseBody = value.String + } + case idempotencyrecord.FieldErrorReason: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field error_reason", values[i]) + } else if value.Valid { + _m.ErrorReason = new(string) + *_m.ErrorReason = value.String + } + case idempotencyrecord.FieldLockedUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field locked_until", values[i]) + } else if value.Valid { + _m.LockedUntil = new(time.Time) + *_m.LockedUntil = value.Time + } + case idempotencyrecord.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the IdempotencyRecord. +// This includes values selected through modifiers, order, etc. +func (_m *IdempotencyRecord) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this IdempotencyRecord. +// Note that you need to call IdempotencyRecord.Unwrap() before calling this method if this IdempotencyRecord +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *IdempotencyRecord) Update() *IdempotencyRecordUpdateOne { + return NewIdempotencyRecordClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the IdempotencyRecord entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *IdempotencyRecord) Unwrap() *IdempotencyRecord { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: IdempotencyRecord is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *IdempotencyRecord) String() string { + var builder strings.Builder + builder.WriteString("IdempotencyRecord(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("scope=") + builder.WriteString(_m.Scope) + builder.WriteString(", ") + builder.WriteString("idempotency_key_hash=") + builder.WriteString(_m.IdempotencyKeyHash) + builder.WriteString(", ") + builder.WriteString("request_fingerprint=") + builder.WriteString(_m.RequestFingerprint) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(_m.Status) + builder.WriteString(", ") + if v := _m.ResponseStatus; v != nil { + builder.WriteString("response_status=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.ResponseBody; v != nil { + builder.WriteString("response_body=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.ErrorReason; v != nil { + builder.WriteString("error_reason=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.LockedUntil; v != nil { + builder.WriteString("locked_until=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("expires_at=") + builder.WriteString(_m.ExpiresAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// IdempotencyRecords is a parsable slice of IdempotencyRecord. +type IdempotencyRecords []*IdempotencyRecord diff --git a/backend/ent/idempotencyrecord/idempotencyrecord.go b/backend/ent/idempotencyrecord/idempotencyrecord.go new file mode 100644 index 00000000..d9686f60 --- /dev/null +++ b/backend/ent/idempotencyrecord/idempotencyrecord.go @@ -0,0 +1,148 @@ +// Code generated by ent, DO NOT EDIT. + +package idempotencyrecord + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the idempotencyrecord type in the database. + Label = "idempotency_record" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldScope holds the string denoting the scope field in the database. + FieldScope = "scope" + // FieldIdempotencyKeyHash holds the string denoting the idempotency_key_hash field in the database. + FieldIdempotencyKeyHash = "idempotency_key_hash" + // FieldRequestFingerprint holds the string denoting the request_fingerprint field in the database. + FieldRequestFingerprint = "request_fingerprint" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" + // FieldResponseStatus holds the string denoting the response_status field in the database. + FieldResponseStatus = "response_status" + // FieldResponseBody holds the string denoting the response_body field in the database. + FieldResponseBody = "response_body" + // FieldErrorReason holds the string denoting the error_reason field in the database. + FieldErrorReason = "error_reason" + // FieldLockedUntil holds the string denoting the locked_until field in the database. + FieldLockedUntil = "locked_until" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" + // Table holds the table name of the idempotencyrecord in the database. + Table = "idempotency_records" +) + +// Columns holds all SQL columns for idempotencyrecord fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldScope, + FieldIdempotencyKeyHash, + FieldRequestFingerprint, + FieldStatus, + FieldResponseStatus, + FieldResponseBody, + FieldErrorReason, + FieldLockedUntil, + FieldExpiresAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ScopeValidator is a validator for the "scope" field. It is called by the builders before save. + ScopeValidator func(string) error + // IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save. + IdempotencyKeyHashValidator func(string) error + // RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save. + RequestFingerprintValidator func(string) error + // StatusValidator is a validator for the "status" field. It is called by the builders before save. + StatusValidator func(string) error + // ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. + ErrorReasonValidator func(string) error +) + +// OrderOption defines the ordering options for the IdempotencyRecord queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByScope orders the results by the scope field. +func ByScope(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScope, opts...).ToFunc() +} + +// ByIdempotencyKeyHash orders the results by the idempotency_key_hash field. +func ByIdempotencyKeyHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdempotencyKeyHash, opts...).ToFunc() +} + +// ByRequestFingerprint orders the results by the request_fingerprint field. +func ByRequestFingerprint(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequestFingerprint, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByResponseStatus orders the results by the response_status field. +func ByResponseStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResponseStatus, opts...).ToFunc() +} + +// ByResponseBody orders the results by the response_body field. +func ByResponseBody(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResponseBody, opts...).ToFunc() +} + +// ByErrorReason orders the results by the error_reason field. +func ByErrorReason(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldErrorReason, opts...).ToFunc() +} + +// ByLockedUntil orders the results by the locked_until field. +func ByLockedUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLockedUntil, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} diff --git a/backend/ent/idempotencyrecord/where.go b/backend/ent/idempotencyrecord/where.go new file mode 100644 index 00000000..c3d8d9d5 --- /dev/null +++ b/backend/ent/idempotencyrecord/where.go @@ -0,0 +1,755 @@ +// Code generated by ent, DO NOT EDIT. + +package idempotencyrecord + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Scope applies equality check predicate on the "scope" field. It's identical to ScopeEQ. +func Scope(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v)) +} + +// IdempotencyKeyHash applies equality check predicate on the "idempotency_key_hash" field. It's identical to IdempotencyKeyHashEQ. +func IdempotencyKeyHash(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v)) +} + +// RequestFingerprint applies equality check predicate on the "request_fingerprint" field. It's identical to RequestFingerprintEQ. +func RequestFingerprint(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v)) +} + +// Status applies equality check predicate on the "status" field. It's identical to StatusEQ. +func Status(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v)) +} + +// ResponseStatus applies equality check predicate on the "response_status" field. It's identical to ResponseStatusEQ. +func ResponseStatus(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v)) +} + +// ResponseBody applies equality check predicate on the "response_body" field. It's identical to ResponseBodyEQ. +func ResponseBody(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v)) +} + +// ErrorReason applies equality check predicate on the "error_reason" field. It's identical to ErrorReasonEQ. +func ErrorReason(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v)) +} + +// LockedUntil applies equality check predicate on the "locked_until" field. It's identical to LockedUntilEQ. +func LockedUntil(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// ScopeEQ applies the EQ predicate on the "scope" field. +func ScopeEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v)) +} + +// ScopeNEQ applies the NEQ predicate on the "scope" field. +func ScopeNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldScope, v)) +} + +// ScopeIn applies the In predicate on the "scope" field. +func ScopeIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldScope, vs...)) +} + +// ScopeNotIn applies the NotIn predicate on the "scope" field. +func ScopeNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldScope, vs...)) +} + +// ScopeGT applies the GT predicate on the "scope" field. +func ScopeGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldScope, v)) +} + +// ScopeGTE applies the GTE predicate on the "scope" field. +func ScopeGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldScope, v)) +} + +// ScopeLT applies the LT predicate on the "scope" field. +func ScopeLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldScope, v)) +} + +// ScopeLTE applies the LTE predicate on the "scope" field. +func ScopeLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldScope, v)) +} + +// ScopeContains applies the Contains predicate on the "scope" field. +func ScopeContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldScope, v)) +} + +// ScopeHasPrefix applies the HasPrefix predicate on the "scope" field. +func ScopeHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldScope, v)) +} + +// ScopeHasSuffix applies the HasSuffix predicate on the "scope" field. +func ScopeHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldScope, v)) +} + +// ScopeEqualFold applies the EqualFold predicate on the "scope" field. +func ScopeEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldScope, v)) +} + +// ScopeContainsFold applies the ContainsFold predicate on the "scope" field. +func ScopeContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldScope, v)) +} + +// IdempotencyKeyHashEQ applies the EQ predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashNEQ applies the NEQ predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashIn applies the In predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldIdempotencyKeyHash, vs...)) +} + +// IdempotencyKeyHashNotIn applies the NotIn predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldIdempotencyKeyHash, vs...)) +} + +// IdempotencyKeyHashGT applies the GT predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashGTE applies the GTE predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashLT applies the LT predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashLTE applies the LTE predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashContains applies the Contains predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashHasPrefix applies the HasPrefix predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashHasSuffix applies the HasSuffix predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashEqualFold applies the EqualFold predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashContainsFold applies the ContainsFold predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldIdempotencyKeyHash, v)) +} + +// RequestFingerprintEQ applies the EQ predicate on the "request_fingerprint" field. +func RequestFingerprintEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v)) +} + +// RequestFingerprintNEQ applies the NEQ predicate on the "request_fingerprint" field. +func RequestFingerprintNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldRequestFingerprint, v)) +} + +// RequestFingerprintIn applies the In predicate on the "request_fingerprint" field. +func RequestFingerprintIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldRequestFingerprint, vs...)) +} + +// RequestFingerprintNotIn applies the NotIn predicate on the "request_fingerprint" field. +func RequestFingerprintNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldRequestFingerprint, vs...)) +} + +// RequestFingerprintGT applies the GT predicate on the "request_fingerprint" field. +func RequestFingerprintGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldRequestFingerprint, v)) +} + +// RequestFingerprintGTE applies the GTE predicate on the "request_fingerprint" field. +func RequestFingerprintGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldRequestFingerprint, v)) +} + +// RequestFingerprintLT applies the LT predicate on the "request_fingerprint" field. +func RequestFingerprintLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldRequestFingerprint, v)) +} + +// RequestFingerprintLTE applies the LTE predicate on the "request_fingerprint" field. +func RequestFingerprintLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldRequestFingerprint, v)) +} + +// RequestFingerprintContains applies the Contains predicate on the "request_fingerprint" field. +func RequestFingerprintContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldRequestFingerprint, v)) +} + +// RequestFingerprintHasPrefix applies the HasPrefix predicate on the "request_fingerprint" field. +func RequestFingerprintHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldRequestFingerprint, v)) +} + +// RequestFingerprintHasSuffix applies the HasSuffix predicate on the "request_fingerprint" field. +func RequestFingerprintHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldRequestFingerprint, v)) +} + +// RequestFingerprintEqualFold applies the EqualFold predicate on the "request_fingerprint" field. +func RequestFingerprintEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldRequestFingerprint, v)) +} + +// RequestFingerprintContainsFold applies the ContainsFold predicate on the "request_fingerprint" field. +func RequestFingerprintContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldRequestFingerprint, v)) +} + +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldStatus, v)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldStatus, vs...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldStatus, vs...)) +} + +// StatusGT applies the GT predicate on the "status" field. +func StatusGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldStatus, v)) +} + +// StatusGTE applies the GTE predicate on the "status" field. +func StatusGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldStatus, v)) +} + +// StatusLT applies the LT predicate on the "status" field. +func StatusLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldStatus, v)) +} + +// StatusLTE applies the LTE predicate on the "status" field. +func StatusLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldStatus, v)) +} + +// StatusContains applies the Contains predicate on the "status" field. +func StatusContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldStatus, v)) +} + +// StatusHasPrefix applies the HasPrefix predicate on the "status" field. +func StatusHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldStatus, v)) +} + +// StatusHasSuffix applies the HasSuffix predicate on the "status" field. +func StatusHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldStatus, v)) +} + +// StatusEqualFold applies the EqualFold predicate on the "status" field. +func StatusEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldStatus, v)) +} + +// StatusContainsFold applies the ContainsFold predicate on the "status" field. +func StatusContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldStatus, v)) +} + +// ResponseStatusEQ applies the EQ predicate on the "response_status" field. +func ResponseStatusEQ(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v)) +} + +// ResponseStatusNEQ applies the NEQ predicate on the "response_status" field. +func ResponseStatusNEQ(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseStatus, v)) +} + +// ResponseStatusIn applies the In predicate on the "response_status" field. +func ResponseStatusIn(vs ...int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseStatus, vs...)) +} + +// ResponseStatusNotIn applies the NotIn predicate on the "response_status" field. +func ResponseStatusNotIn(vs ...int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseStatus, vs...)) +} + +// ResponseStatusGT applies the GT predicate on the "response_status" field. +func ResponseStatusGT(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseStatus, v)) +} + +// ResponseStatusGTE applies the GTE predicate on the "response_status" field. +func ResponseStatusGTE(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseStatus, v)) +} + +// ResponseStatusLT applies the LT predicate on the "response_status" field. +func ResponseStatusLT(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseStatus, v)) +} + +// ResponseStatusLTE applies the LTE predicate on the "response_status" field. +func ResponseStatusLTE(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseStatus, v)) +} + +// ResponseStatusIsNil applies the IsNil predicate on the "response_status" field. +func ResponseStatusIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseStatus)) +} + +// ResponseStatusNotNil applies the NotNil predicate on the "response_status" field. +func ResponseStatusNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseStatus)) +} + +// ResponseBodyEQ applies the EQ predicate on the "response_body" field. +func ResponseBodyEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v)) +} + +// ResponseBodyNEQ applies the NEQ predicate on the "response_body" field. +func ResponseBodyNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseBody, v)) +} + +// ResponseBodyIn applies the In predicate on the "response_body" field. +func ResponseBodyIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseBody, vs...)) +} + +// ResponseBodyNotIn applies the NotIn predicate on the "response_body" field. +func ResponseBodyNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseBody, vs...)) +} + +// ResponseBodyGT applies the GT predicate on the "response_body" field. +func ResponseBodyGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseBody, v)) +} + +// ResponseBodyGTE applies the GTE predicate on the "response_body" field. +func ResponseBodyGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseBody, v)) +} + +// ResponseBodyLT applies the LT predicate on the "response_body" field. +func ResponseBodyLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseBody, v)) +} + +// ResponseBodyLTE applies the LTE predicate on the "response_body" field. +func ResponseBodyLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseBody, v)) +} + +// ResponseBodyContains applies the Contains predicate on the "response_body" field. +func ResponseBodyContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldResponseBody, v)) +} + +// ResponseBodyHasPrefix applies the HasPrefix predicate on the "response_body" field. +func ResponseBodyHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldResponseBody, v)) +} + +// ResponseBodyHasSuffix applies the HasSuffix predicate on the "response_body" field. +func ResponseBodyHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldResponseBody, v)) +} + +// ResponseBodyIsNil applies the IsNil predicate on the "response_body" field. +func ResponseBodyIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseBody)) +} + +// ResponseBodyNotNil applies the NotNil predicate on the "response_body" field. +func ResponseBodyNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseBody)) +} + +// ResponseBodyEqualFold applies the EqualFold predicate on the "response_body" field. +func ResponseBodyEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldResponseBody, v)) +} + +// ResponseBodyContainsFold applies the ContainsFold predicate on the "response_body" field. +func ResponseBodyContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldResponseBody, v)) +} + +// ErrorReasonEQ applies the EQ predicate on the "error_reason" field. +func ErrorReasonEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v)) +} + +// ErrorReasonNEQ applies the NEQ predicate on the "error_reason" field. +func ErrorReasonNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldErrorReason, v)) +} + +// ErrorReasonIn applies the In predicate on the "error_reason" field. +func ErrorReasonIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldErrorReason, vs...)) +} + +// ErrorReasonNotIn applies the NotIn predicate on the "error_reason" field. +func ErrorReasonNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldErrorReason, vs...)) +} + +// ErrorReasonGT applies the GT predicate on the "error_reason" field. +func ErrorReasonGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldErrorReason, v)) +} + +// ErrorReasonGTE applies the GTE predicate on the "error_reason" field. +func ErrorReasonGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldErrorReason, v)) +} + +// ErrorReasonLT applies the LT predicate on the "error_reason" field. +func ErrorReasonLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldErrorReason, v)) +} + +// ErrorReasonLTE applies the LTE predicate on the "error_reason" field. +func ErrorReasonLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldErrorReason, v)) +} + +// ErrorReasonContains applies the Contains predicate on the "error_reason" field. +func ErrorReasonContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldErrorReason, v)) +} + +// ErrorReasonHasPrefix applies the HasPrefix predicate on the "error_reason" field. +func ErrorReasonHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldErrorReason, v)) +} + +// ErrorReasonHasSuffix applies the HasSuffix predicate on the "error_reason" field. +func ErrorReasonHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldErrorReason, v)) +} + +// ErrorReasonIsNil applies the IsNil predicate on the "error_reason" field. +func ErrorReasonIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldErrorReason)) +} + +// ErrorReasonNotNil applies the NotNil predicate on the "error_reason" field. +func ErrorReasonNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldErrorReason)) +} + +// ErrorReasonEqualFold applies the EqualFold predicate on the "error_reason" field. +func ErrorReasonEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldErrorReason, v)) +} + +// ErrorReasonContainsFold applies the ContainsFold predicate on the "error_reason" field. +func ErrorReasonContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldErrorReason, v)) +} + +// LockedUntilEQ applies the EQ predicate on the "locked_until" field. +func LockedUntilEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v)) +} + +// LockedUntilNEQ applies the NEQ predicate on the "locked_until" field. +func LockedUntilNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldLockedUntil, v)) +} + +// LockedUntilIn applies the In predicate on the "locked_until" field. +func LockedUntilIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldLockedUntil, vs...)) +} + +// LockedUntilNotIn applies the NotIn predicate on the "locked_until" field. +func LockedUntilNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldLockedUntil, vs...)) +} + +// LockedUntilGT applies the GT predicate on the "locked_until" field. +func LockedUntilGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldLockedUntil, v)) +} + +// LockedUntilGTE applies the GTE predicate on the "locked_until" field. +func LockedUntilGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldLockedUntil, v)) +} + +// LockedUntilLT applies the LT predicate on the "locked_until" field. +func LockedUntilLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldLockedUntil, v)) +} + +// LockedUntilLTE applies the LTE predicate on the "locked_until" field. +func LockedUntilLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldLockedUntil, v)) +} + +// LockedUntilIsNil applies the IsNil predicate on the "locked_until" field. +func LockedUntilIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldLockedUntil)) +} + +// LockedUntilNotNil applies the NotNil predicate on the "locked_until" field. +func LockedUntilNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldLockedUntil)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldExpiresAt, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.NotPredicates(p)) +} diff --git a/backend/ent/idempotencyrecord_create.go b/backend/ent/idempotencyrecord_create.go new file mode 100644 index 00000000..bf4deaf2 --- /dev/null +++ b/backend/ent/idempotencyrecord_create.go @@ -0,0 +1,1132 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" +) + +// IdempotencyRecordCreate is the builder for creating a IdempotencyRecord entity. +type IdempotencyRecordCreate struct { + config + mutation *IdempotencyRecordMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *IdempotencyRecordCreate) SetCreatedAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableCreatedAt(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *IdempotencyRecordCreate) SetUpdatedAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableUpdatedAt(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetScope sets the "scope" field. +func (_c *IdempotencyRecordCreate) SetScope(v string) *IdempotencyRecordCreate { + _c.mutation.SetScope(v) + return _c +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (_c *IdempotencyRecordCreate) SetIdempotencyKeyHash(v string) *IdempotencyRecordCreate { + _c.mutation.SetIdempotencyKeyHash(v) + return _c +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_c *IdempotencyRecordCreate) SetRequestFingerprint(v string) *IdempotencyRecordCreate { + _c.mutation.SetRequestFingerprint(v) + return _c +} + +// SetStatus sets the "status" field. +func (_c *IdempotencyRecordCreate) SetStatus(v string) *IdempotencyRecordCreate { + _c.mutation.SetStatus(v) + return _c +} + +// SetResponseStatus sets the "response_status" field. +func (_c *IdempotencyRecordCreate) SetResponseStatus(v int) *IdempotencyRecordCreate { + _c.mutation.SetResponseStatus(v) + return _c +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableResponseStatus(v *int) *IdempotencyRecordCreate { + if v != nil { + _c.SetResponseStatus(*v) + } + return _c +} + +// SetResponseBody sets the "response_body" field. +func (_c *IdempotencyRecordCreate) SetResponseBody(v string) *IdempotencyRecordCreate { + _c.mutation.SetResponseBody(v) + return _c +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableResponseBody(v *string) *IdempotencyRecordCreate { + if v != nil { + _c.SetResponseBody(*v) + } + return _c +} + +// SetErrorReason sets the "error_reason" field. +func (_c *IdempotencyRecordCreate) SetErrorReason(v string) *IdempotencyRecordCreate { + _c.mutation.SetErrorReason(v) + return _c +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableErrorReason(v *string) *IdempotencyRecordCreate { + if v != nil { + _c.SetErrorReason(*v) + } + return _c +} + +// SetLockedUntil sets the "locked_until" field. +func (_c *IdempotencyRecordCreate) SetLockedUntil(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetLockedUntil(v) + return _c +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetLockedUntil(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *IdempotencyRecordCreate) SetExpiresAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_c *IdempotencyRecordCreate) Mutation() *IdempotencyRecordMutation { + return _c.mutation +} + +// Save creates the IdempotencyRecord in the database. +func (_c *IdempotencyRecordCreate) Save(ctx context.Context) (*IdempotencyRecord, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *IdempotencyRecordCreate) SaveX(ctx context.Context) *IdempotencyRecord { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdempotencyRecordCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdempotencyRecordCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *IdempotencyRecordCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := idempotencyrecord.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *IdempotencyRecordCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdempotencyRecord.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdempotencyRecord.updated_at"`)} + } + if _, ok := _c.mutation.Scope(); !ok { + return &ValidationError{Name: "scope", err: errors.New(`ent: missing required field "IdempotencyRecord.scope"`)} + } + if v, ok := _c.mutation.Scope(); ok { + if err := idempotencyrecord.ScopeValidator(v); err != nil { + return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)} + } + } + if _, ok := _c.mutation.IdempotencyKeyHash(); !ok { + return &ValidationError{Name: "idempotency_key_hash", err: errors.New(`ent: missing required field "IdempotencyRecord.idempotency_key_hash"`)} + } + if v, ok := _c.mutation.IdempotencyKeyHash(); ok { + if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil { + return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)} + } + } + if _, ok := _c.mutation.RequestFingerprint(); !ok { + return &ValidationError{Name: "request_fingerprint", err: errors.New(`ent: missing required field "IdempotencyRecord.request_fingerprint"`)} + } + if v, ok := _c.mutation.RequestFingerprint(); ok { + if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil { + return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)} + } + } + if _, ok := _c.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "IdempotencyRecord.status"`)} + } + if v, ok := _c.mutation.Status(); ok { + if err := idempotencyrecord.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)} + } + } + if v, ok := _c.mutation.ErrorReason(); ok { + if err := idempotencyrecord.ErrorReasonValidator(v); err != nil { + return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)} + } + } + if _, ok := _c.mutation.ExpiresAt(); !ok { + return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "IdempotencyRecord.expires_at"`)} + } + return nil +} + +func (_c *IdempotencyRecordCreate) sqlSave(ctx context.Context) (*IdempotencyRecord, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *IdempotencyRecordCreate) createSpec() (*IdempotencyRecord, *sqlgraph.CreateSpec) { + var ( + _node = &IdempotencyRecord{config: _c.config} + _spec = sqlgraph.NewCreateSpec(idempotencyrecord.Table, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Scope(); ok { + _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value) + _node.Scope = value + } + if value, ok := _c.mutation.IdempotencyKeyHash(); ok { + _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value) + _node.IdempotencyKeyHash = value + } + if value, ok := _c.mutation.RequestFingerprint(); ok { + _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value) + _node.RequestFingerprint = value + } + if value, ok := _c.mutation.Status(); ok { + _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value) + _node.Status = value + } + if value, ok := _c.mutation.ResponseStatus(); ok { + _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + _node.ResponseStatus = &value + } + if value, ok := _c.mutation.ResponseBody(); ok { + _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value) + _node.ResponseBody = &value + } + if value, ok := _c.mutation.ErrorReason(); ok { + _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value) + _node.ErrorReason = &value + } + if value, ok := _c.mutation.LockedUntil(); ok { + _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value) + _node.LockedUntil = &value + } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdempotencyRecord.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdempotencyRecordUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdempotencyRecordCreate) OnConflict(opts ...sql.ConflictOption) *IdempotencyRecordUpsertOne { + _c.conflict = opts + return &IdempotencyRecordUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdempotencyRecordCreate) OnConflictColumns(columns ...string) *IdempotencyRecordUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdempotencyRecordUpsertOne{ + create: _c, + } +} + +type ( + // IdempotencyRecordUpsertOne is the builder for "upsert"-ing + // one IdempotencyRecord node. + IdempotencyRecordUpsertOne struct { + create *IdempotencyRecordCreate + } + + // IdempotencyRecordUpsert is the "OnConflict" setter. + IdempotencyRecordUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdempotencyRecordUpsert) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateUpdatedAt() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldUpdatedAt) + return u +} + +// SetScope sets the "scope" field. +func (u *IdempotencyRecordUpsert) SetScope(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldScope, v) + return u +} + +// UpdateScope sets the "scope" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateScope() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldScope) + return u +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (u *IdempotencyRecordUpsert) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldIdempotencyKeyHash, v) + return u +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldIdempotencyKeyHash) + return u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (u *IdempotencyRecordUpsert) SetRequestFingerprint(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldRequestFingerprint, v) + return u +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateRequestFingerprint() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldRequestFingerprint) + return u +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsert) SetStatus(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateStatus() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldStatus) + return u +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsert) SetResponseStatus(v int) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldResponseStatus, v) + return u +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateResponseStatus() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldResponseStatus) + return u +} + +// AddResponseStatus adds v to the "response_status" field. +func (u *IdempotencyRecordUpsert) AddResponseStatus(v int) *IdempotencyRecordUpsert { + u.Add(idempotencyrecord.FieldResponseStatus, v) + return u +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (u *IdempotencyRecordUpsert) ClearResponseStatus() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldResponseStatus) + return u +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsert) SetResponseBody(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldResponseBody, v) + return u +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateResponseBody() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldResponseBody) + return u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsert) ClearResponseBody() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldResponseBody) + return u +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsert) SetErrorReason(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldErrorReason, v) + return u +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateErrorReason() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldErrorReason) + return u +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsert) ClearErrorReason() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldErrorReason) + return u +} + +// SetLockedUntil sets the "locked_until" field. +func (u *IdempotencyRecordUpsert) SetLockedUntil(v time.Time) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldLockedUntil, v) + return u +} + +// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateLockedUntil() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldLockedUntil) + return u +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (u *IdempotencyRecordUpsert) ClearLockedUntil() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldLockedUntil) + return u +} + +// SetExpiresAt sets the "expires_at" field. +func (u *IdempotencyRecordUpsert) SetExpiresAt(v time.Time) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateExpiresAt() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldExpiresAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdempotencyRecordUpsertOne) UpdateNewValues() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(idempotencyrecord.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdempotencyRecordUpsertOne) Ignore() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdempotencyRecordUpsertOne) DoNothing() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdempotencyRecordCreate.OnConflict +// documentation for more info. +func (u *IdempotencyRecordUpsertOne) Update(set func(*IdempotencyRecordUpsert)) *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdempotencyRecordUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdempotencyRecordUpsertOne) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateUpdatedAt() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetScope sets the "scope" field. +func (u *IdempotencyRecordUpsertOne) SetScope(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetScope(v) + }) +} + +// UpdateScope sets the "scope" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateScope() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateScope() + }) +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (u *IdempotencyRecordUpsertOne) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetIdempotencyKeyHash(v) + }) +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateIdempotencyKeyHash() + }) +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (u *IdempotencyRecordUpsertOne) SetRequestFingerprint(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetRequestFingerprint(v) + }) +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateRequestFingerprint() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateRequestFingerprint() + }) +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsertOne) SetStatus(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateStatus() + }) +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsertOne) SetResponseStatus(v int) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseStatus(v) + }) +} + +// AddResponseStatus adds v to the "response_status" field. +func (u *IdempotencyRecordUpsertOne) AddResponseStatus(v int) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.AddResponseStatus(v) + }) +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateResponseStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseStatus() + }) +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (u *IdempotencyRecordUpsertOne) ClearResponseStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseStatus() + }) +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsertOne) SetResponseBody(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseBody(v) + }) +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateResponseBody() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseBody() + }) +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsertOne) ClearResponseBody() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseBody() + }) +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsertOne) SetErrorReason(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetErrorReason(v) + }) +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateErrorReason() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateErrorReason() + }) +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsertOne) ClearErrorReason() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearErrorReason() + }) +} + +// SetLockedUntil sets the "locked_until" field. +func (u *IdempotencyRecordUpsertOne) SetLockedUntil(v time.Time) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetLockedUntil(v) + }) +} + +// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateLockedUntil() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateLockedUntil() + }) +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (u *IdempotencyRecordUpsertOne) ClearLockedUntil() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearLockedUntil() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *IdempotencyRecordUpsertOne) SetExpiresAt(v time.Time) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateExpiresAt() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateExpiresAt() + }) +} + +// Exec executes the query. +func (u *IdempotencyRecordUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdempotencyRecordCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdempotencyRecordUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *IdempotencyRecordUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *IdempotencyRecordUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// IdempotencyRecordCreateBulk is the builder for creating many IdempotencyRecord entities in bulk. +type IdempotencyRecordCreateBulk struct { + config + err error + builders []*IdempotencyRecordCreate + conflict []sql.ConflictOption +} + +// Save creates the IdempotencyRecord entities in the database. +func (_c *IdempotencyRecordCreateBulk) Save(ctx context.Context) ([]*IdempotencyRecord, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*IdempotencyRecord, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*IdempotencyRecordMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *IdempotencyRecordCreateBulk) SaveX(ctx context.Context) []*IdempotencyRecord { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdempotencyRecordCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdempotencyRecordCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdempotencyRecord.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdempotencyRecordUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdempotencyRecordCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdempotencyRecordUpsertBulk { + _c.conflict = opts + return &IdempotencyRecordUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdempotencyRecordCreateBulk) OnConflictColumns(columns ...string) *IdempotencyRecordUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdempotencyRecordUpsertBulk{ + create: _c, + } +} + +// IdempotencyRecordUpsertBulk is the builder for "upsert"-ing +// a bulk of IdempotencyRecord nodes. +type IdempotencyRecordUpsertBulk struct { + create *IdempotencyRecordCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdempotencyRecordUpsertBulk) UpdateNewValues() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(idempotencyrecord.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdempotencyRecordUpsertBulk) Ignore() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdempotencyRecordUpsertBulk) DoNothing() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdempotencyRecordCreateBulk.OnConflict +// documentation for more info. +func (u *IdempotencyRecordUpsertBulk) Update(set func(*IdempotencyRecordUpsert)) *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdempotencyRecordUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdempotencyRecordUpsertBulk) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateUpdatedAt() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetScope sets the "scope" field. +func (u *IdempotencyRecordUpsertBulk) SetScope(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetScope(v) + }) +} + +// UpdateScope sets the "scope" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateScope() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateScope() + }) +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (u *IdempotencyRecordUpsertBulk) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetIdempotencyKeyHash(v) + }) +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateIdempotencyKeyHash() + }) +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (u *IdempotencyRecordUpsertBulk) SetRequestFingerprint(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetRequestFingerprint(v) + }) +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateRequestFingerprint() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateRequestFingerprint() + }) +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsertBulk) SetStatus(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateStatus() + }) +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsertBulk) SetResponseStatus(v int) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseStatus(v) + }) +} + +// AddResponseStatus adds v to the "response_status" field. +func (u *IdempotencyRecordUpsertBulk) AddResponseStatus(v int) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.AddResponseStatus(v) + }) +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateResponseStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseStatus() + }) +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (u *IdempotencyRecordUpsertBulk) ClearResponseStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseStatus() + }) +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsertBulk) SetResponseBody(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseBody(v) + }) +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateResponseBody() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseBody() + }) +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsertBulk) ClearResponseBody() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseBody() + }) +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsertBulk) SetErrorReason(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetErrorReason(v) + }) +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateErrorReason() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateErrorReason() + }) +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsertBulk) ClearErrorReason() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearErrorReason() + }) +} + +// SetLockedUntil sets the "locked_until" field. +func (u *IdempotencyRecordUpsertBulk) SetLockedUntil(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetLockedUntil(v) + }) +} + +// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateLockedUntil() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateLockedUntil() + }) +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (u *IdempotencyRecordUpsertBulk) ClearLockedUntil() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearLockedUntil() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *IdempotencyRecordUpsertBulk) SetExpiresAt(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateExpiresAt() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateExpiresAt() + }) +} + +// Exec executes the query. +func (u *IdempotencyRecordUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdempotencyRecordCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdempotencyRecordCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdempotencyRecordUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/idempotencyrecord_delete.go b/backend/ent/idempotencyrecord_delete.go new file mode 100644 index 00000000..f5c87559 --- /dev/null +++ b/backend/ent/idempotencyrecord_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordDelete is the builder for deleting a IdempotencyRecord entity. +type IdempotencyRecordDelete struct { + config + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// Where appends a list predicates to the IdempotencyRecordDelete builder. +func (_d *IdempotencyRecordDelete) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *IdempotencyRecordDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdempotencyRecordDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *IdempotencyRecordDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(idempotencyrecord.Table, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// IdempotencyRecordDeleteOne is the builder for deleting a single IdempotencyRecord entity. +type IdempotencyRecordDeleteOne struct { + _d *IdempotencyRecordDelete +} + +// Where appends a list predicates to the IdempotencyRecordDelete builder. +func (_d *IdempotencyRecordDeleteOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *IdempotencyRecordDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{idempotencyrecord.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdempotencyRecordDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/idempotencyrecord_query.go b/backend/ent/idempotencyrecord_query.go new file mode 100644 index 00000000..fbba4dfa --- /dev/null +++ b/backend/ent/idempotencyrecord_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordQuery is the builder for querying IdempotencyRecord entities. +type IdempotencyRecordQuery struct { + config + ctx *QueryContext + order []idempotencyrecord.OrderOption + inters []Interceptor + predicates []predicate.IdempotencyRecord + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the IdempotencyRecordQuery builder. +func (_q *IdempotencyRecordQuery) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *IdempotencyRecordQuery) Limit(limit int) *IdempotencyRecordQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *IdempotencyRecordQuery) Offset(offset int) *IdempotencyRecordQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *IdempotencyRecordQuery) Unique(unique bool) *IdempotencyRecordQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *IdempotencyRecordQuery) Order(o ...idempotencyrecord.OrderOption) *IdempotencyRecordQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first IdempotencyRecord entity from the query. +// Returns a *NotFoundError when no IdempotencyRecord was found. +func (_q *IdempotencyRecordQuery) First(ctx context.Context) (*IdempotencyRecord, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{idempotencyrecord.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) FirstX(ctx context.Context) *IdempotencyRecord { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first IdempotencyRecord ID from the query. +// Returns a *NotFoundError when no IdempotencyRecord ID was found. +func (_q *IdempotencyRecordQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{idempotencyrecord.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single IdempotencyRecord entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one IdempotencyRecord entity is found. +// Returns a *NotFoundError when no IdempotencyRecord entities are found. +func (_q *IdempotencyRecordQuery) Only(ctx context.Context) (*IdempotencyRecord, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{idempotencyrecord.Label} + default: + return nil, &NotSingularError{idempotencyrecord.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) OnlyX(ctx context.Context) *IdempotencyRecord { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only IdempotencyRecord ID in the query. +// Returns a *NotSingularError when more than one IdempotencyRecord ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *IdempotencyRecordQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{idempotencyrecord.Label} + default: + err = &NotSingularError{idempotencyrecord.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of IdempotencyRecords. +func (_q *IdempotencyRecordQuery) All(ctx context.Context) ([]*IdempotencyRecord, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*IdempotencyRecord, *IdempotencyRecordQuery]() + return withInterceptors[[]*IdempotencyRecord](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) AllX(ctx context.Context) []*IdempotencyRecord { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of IdempotencyRecord IDs. +func (_q *IdempotencyRecordQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(idempotencyrecord.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *IdempotencyRecordQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*IdempotencyRecordQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *IdempotencyRecordQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the IdempotencyRecordQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *IdempotencyRecordQuery) Clone() *IdempotencyRecordQuery { + if _q == nil { + return nil + } + return &IdempotencyRecordQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]idempotencyrecord.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.IdempotencyRecord{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.IdempotencyRecord.Query(). +// GroupBy(idempotencyrecord.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *IdempotencyRecordQuery) GroupBy(field string, fields ...string) *IdempotencyRecordGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &IdempotencyRecordGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = idempotencyrecord.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.IdempotencyRecord.Query(). +// Select(idempotencyrecord.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *IdempotencyRecordQuery) Select(fields ...string) *IdempotencyRecordSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &IdempotencyRecordSelect{IdempotencyRecordQuery: _q} + sbuild.label = idempotencyrecord.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a IdempotencyRecordSelect configured with the given aggregations. +func (_q *IdempotencyRecordQuery) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *IdempotencyRecordQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !idempotencyrecord.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *IdempotencyRecordQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdempotencyRecord, error) { + var ( + nodes = []*IdempotencyRecord{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*IdempotencyRecord).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &IdempotencyRecord{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *IdempotencyRecordQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *IdempotencyRecordQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID) + for i := range fields { + if fields[i] != idempotencyrecord.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *IdempotencyRecordQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(idempotencyrecord.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = idempotencyrecord.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *IdempotencyRecordQuery) ForUpdate(opts ...sql.LockOption) *IdempotencyRecordQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *IdempotencyRecordQuery) ForShare(opts ...sql.LockOption) *IdempotencyRecordQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// IdempotencyRecordGroupBy is the group-by builder for IdempotencyRecord entities. +type IdempotencyRecordGroupBy struct { + selector + build *IdempotencyRecordQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *IdempotencyRecordGroupBy) Aggregate(fns ...AggregateFunc) *IdempotencyRecordGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *IdempotencyRecordGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *IdempotencyRecordGroupBy) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// IdempotencyRecordSelect is the builder for selecting fields of IdempotencyRecord entities. +type IdempotencyRecordSelect struct { + *IdempotencyRecordQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *IdempotencyRecordSelect) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *IdempotencyRecordSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordSelect](ctx, _s.IdempotencyRecordQuery, _s, _s.inters, v) +} + +func (_s *IdempotencyRecordSelect) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/idempotencyrecord_update.go b/backend/ent/idempotencyrecord_update.go new file mode 100644 index 00000000..f839e5c0 --- /dev/null +++ b/backend/ent/idempotencyrecord_update.go @@ -0,0 +1,676 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordUpdate is the builder for updating IdempotencyRecord entities. +type IdempotencyRecordUpdate struct { + config + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// Where appends a list predicates to the IdempotencyRecordUpdate builder. +func (_u *IdempotencyRecordUpdate) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdempotencyRecordUpdate) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetScope sets the "scope" field. +func (_u *IdempotencyRecordUpdate) SetScope(v string) *IdempotencyRecordUpdate { + _u.mutation.SetScope(v) + return _u +} + +// SetNillableScope sets the "scope" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableScope(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetScope(*v) + } + return _u +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (_u *IdempotencyRecordUpdate) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdate { + _u.mutation.SetIdempotencyKeyHash(v) + return _u +} + +// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetIdempotencyKeyHash(*v) + } + return _u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_u *IdempotencyRecordUpdate) SetRequestFingerprint(v string) *IdempotencyRecordUpdate { + _u.mutation.SetRequestFingerprint(v) + return _u +} + +// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetRequestFingerprint(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *IdempotencyRecordUpdate) SetStatus(v string) *IdempotencyRecordUpdate { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableStatus(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetResponseStatus sets the "response_status" field. +func (_u *IdempotencyRecordUpdate) SetResponseStatus(v int) *IdempotencyRecordUpdate { + _u.mutation.ResetResponseStatus() + _u.mutation.SetResponseStatus(v) + return _u +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdate { + if v != nil { + _u.SetResponseStatus(*v) + } + return _u +} + +// AddResponseStatus adds value to the "response_status" field. +func (_u *IdempotencyRecordUpdate) AddResponseStatus(v int) *IdempotencyRecordUpdate { + _u.mutation.AddResponseStatus(v) + return _u +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (_u *IdempotencyRecordUpdate) ClearResponseStatus() *IdempotencyRecordUpdate { + _u.mutation.ClearResponseStatus() + return _u +} + +// SetResponseBody sets the "response_body" field. +func (_u *IdempotencyRecordUpdate) SetResponseBody(v string) *IdempotencyRecordUpdate { + _u.mutation.SetResponseBody(v) + return _u +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableResponseBody(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetResponseBody(*v) + } + return _u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (_u *IdempotencyRecordUpdate) ClearResponseBody() *IdempotencyRecordUpdate { + _u.mutation.ClearResponseBody() + return _u +} + +// SetErrorReason sets the "error_reason" field. +func (_u *IdempotencyRecordUpdate) SetErrorReason(v string) *IdempotencyRecordUpdate { + _u.mutation.SetErrorReason(v) + return _u +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableErrorReason(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetErrorReason(*v) + } + return _u +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (_u *IdempotencyRecordUpdate) ClearErrorReason() *IdempotencyRecordUpdate { + _u.mutation.ClearErrorReason() + return _u +} + +// SetLockedUntil sets the "locked_until" field. +func (_u *IdempotencyRecordUpdate) SetLockedUntil(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetLockedUntil(v) + return _u +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdate { + if v != nil { + _u.SetLockedUntil(*v) + } + return _u +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (_u *IdempotencyRecordUpdate) ClearLockedUntil() *IdempotencyRecordUpdate { + _u.mutation.ClearLockedUntil() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *IdempotencyRecordUpdate) SetExpiresAt(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_u *IdempotencyRecordUpdate) Mutation() *IdempotencyRecordMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *IdempotencyRecordUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdempotencyRecordUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *IdempotencyRecordUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdempotencyRecordUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdempotencyRecordUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdempotencyRecordUpdate) check() error { + if v, ok := _u.mutation.Scope(); ok { + if err := idempotencyrecord.ScopeValidator(v); err != nil { + return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)} + } + } + if v, ok := _u.mutation.IdempotencyKeyHash(); ok { + if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil { + return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)} + } + } + if v, ok := _u.mutation.RequestFingerprint(); ok { + if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil { + return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := idempotencyrecord.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)} + } + } + if v, ok := _u.mutation.ErrorReason(); ok { + if err := idempotencyrecord.ErrorReasonValidator(v); err != nil { + return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)} + } + } + return nil +} + +func (_u *IdempotencyRecordUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Scope(); ok { + _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value) + } + if value, ok := _u.mutation.IdempotencyKeyHash(); ok { + _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value) + } + if value, ok := _u.mutation.RequestFingerprint(); ok { + _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.ResponseStatus(); ok { + _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseStatus(); ok { + _spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if _u.mutation.ResponseStatusCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseStatus, field.TypeInt) + } + if value, ok := _u.mutation.ResponseBody(); ok { + _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value) + } + if _u.mutation.ResponseBodyCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString) + } + if value, ok := _u.mutation.ErrorReason(); ok { + _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value) + } + if _u.mutation.ErrorReasonCleared() { + _spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString) + } + if value, ok := _u.mutation.LockedUntil(); ok { + _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value) + } + if _u.mutation.LockedUntilCleared() { + _spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{idempotencyrecord.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// IdempotencyRecordUpdateOne is the builder for updating a single IdempotencyRecord entity. +type IdempotencyRecordUpdateOne struct { + config + fields []string + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdempotencyRecordUpdateOne) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetScope sets the "scope" field. +func (_u *IdempotencyRecordUpdateOne) SetScope(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetScope(v) + return _u +} + +// SetNillableScope sets the "scope" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableScope(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetScope(*v) + } + return _u +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (_u *IdempotencyRecordUpdateOne) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetIdempotencyKeyHash(v) + return _u +} + +// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetIdempotencyKeyHash(*v) + } + return _u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_u *IdempotencyRecordUpdateOne) SetRequestFingerprint(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetRequestFingerprint(v) + return _u +} + +// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetRequestFingerprint(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *IdempotencyRecordUpdateOne) SetStatus(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableStatus(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetResponseStatus sets the "response_status" field. +func (_u *IdempotencyRecordUpdateOne) SetResponseStatus(v int) *IdempotencyRecordUpdateOne { + _u.mutation.ResetResponseStatus() + _u.mutation.SetResponseStatus(v) + return _u +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetResponseStatus(*v) + } + return _u +} + +// AddResponseStatus adds value to the "response_status" field. +func (_u *IdempotencyRecordUpdateOne) AddResponseStatus(v int) *IdempotencyRecordUpdateOne { + _u.mutation.AddResponseStatus(v) + return _u +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (_u *IdempotencyRecordUpdateOne) ClearResponseStatus() *IdempotencyRecordUpdateOne { + _u.mutation.ClearResponseStatus() + return _u +} + +// SetResponseBody sets the "response_body" field. +func (_u *IdempotencyRecordUpdateOne) SetResponseBody(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetResponseBody(v) + return _u +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableResponseBody(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetResponseBody(*v) + } + return _u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (_u *IdempotencyRecordUpdateOne) ClearResponseBody() *IdempotencyRecordUpdateOne { + _u.mutation.ClearResponseBody() + return _u +} + +// SetErrorReason sets the "error_reason" field. +func (_u *IdempotencyRecordUpdateOne) SetErrorReason(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetErrorReason(v) + return _u +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableErrorReason(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetErrorReason(*v) + } + return _u +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (_u *IdempotencyRecordUpdateOne) ClearErrorReason() *IdempotencyRecordUpdateOne { + _u.mutation.ClearErrorReason() + return _u +} + +// SetLockedUntil sets the "locked_until" field. +func (_u *IdempotencyRecordUpdateOne) SetLockedUntil(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetLockedUntil(v) + return _u +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetLockedUntil(*v) + } + return _u +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (_u *IdempotencyRecordUpdateOne) ClearLockedUntil() *IdempotencyRecordUpdateOne { + _u.mutation.ClearLockedUntil() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *IdempotencyRecordUpdateOne) SetExpiresAt(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_u *IdempotencyRecordUpdateOne) Mutation() *IdempotencyRecordMutation { + return _u.mutation +} + +// Where appends a list predicates to the IdempotencyRecordUpdate builder. +func (_u *IdempotencyRecordUpdateOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *IdempotencyRecordUpdateOne) Select(field string, fields ...string) *IdempotencyRecordUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated IdempotencyRecord entity. +func (_u *IdempotencyRecordUpdateOne) Save(ctx context.Context) (*IdempotencyRecord, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdempotencyRecordUpdateOne) SaveX(ctx context.Context) *IdempotencyRecord { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *IdempotencyRecordUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdempotencyRecordUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdempotencyRecordUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdempotencyRecordUpdateOne) check() error { + if v, ok := _u.mutation.Scope(); ok { + if err := idempotencyrecord.ScopeValidator(v); err != nil { + return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)} + } + } + if v, ok := _u.mutation.IdempotencyKeyHash(); ok { + if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil { + return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)} + } + } + if v, ok := _u.mutation.RequestFingerprint(); ok { + if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil { + return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := idempotencyrecord.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)} + } + } + if v, ok := _u.mutation.ErrorReason(); ok { + if err := idempotencyrecord.ErrorReasonValidator(v); err != nil { + return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)} + } + } + return nil +} + +func (_u *IdempotencyRecordUpdateOne) sqlSave(ctx context.Context) (_node *IdempotencyRecord, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdempotencyRecord.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID) + for _, f := range fields { + if !idempotencyrecord.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != idempotencyrecord.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Scope(); ok { + _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value) + } + if value, ok := _u.mutation.IdempotencyKeyHash(); ok { + _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value) + } + if value, ok := _u.mutation.RequestFingerprint(); ok { + _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.ResponseStatus(); ok { + _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseStatus(); ok { + _spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if _u.mutation.ResponseStatusCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseStatus, field.TypeInt) + } + if value, ok := _u.mutation.ResponseBody(); ok { + _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value) + } + if _u.mutation.ResponseBodyCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString) + } + if value, ok := _u.mutation.ErrorReason(); ok { + _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value) + } + if _u.mutation.ErrorReasonCleared() { + _spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString) + } + if value, ok := _u.mutation.LockedUntil(); ok { + _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value) + } + if _u.mutation.LockedUntilCleared() { + _spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value) + } + _node = &IdempotencyRecord{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{idempotencyrecord.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 8ee42db3..e7746402 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -15,11 +15,13 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -275,6 +277,33 @@ func (f TraverseGroup) Traverse(ctx context.Context, q ent.Query) error { return fmt.Errorf("unexpected query type %T. expect *ent.GroupQuery", q) } +// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary function as a Querier. +type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f IdempotencyRecordFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.IdempotencyRecordQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) +} + +// The TraverseIdempotencyRecord type is an adapter to allow the use of ordinary function as Traverser. +type TraverseIdempotencyRecord func(context.Context, *ent.IdempotencyRecordQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseIdempotencyRecord) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.IdempotencyRecordQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier. type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error) @@ -383,6 +412,33 @@ func (f TraverseRedeemCode) Traverse(ctx context.Context, q ent.Query) error { return fmt.Errorf("unexpected query type %T. expect *ent.RedeemCodeQuery", q) } +// The SecuritySecretFunc type is an adapter to allow the use of ordinary function as a Querier. +type SecuritySecretFunc func(context.Context, *ent.SecuritySecretQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f SecuritySecretFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.SecuritySecretQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.SecuritySecretQuery", q) +} + +// The TraverseSecuritySecret type is an adapter to allow the use of ordinary function as Traverser. +type TraverseSecuritySecret func(context.Context, *ent.SecuritySecretQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseSecuritySecret) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseSecuritySecret) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.SecuritySecretQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.SecuritySecretQuery", q) +} + // The SettingFunc type is an adapter to allow the use of ordinary function as a Querier. type SettingFunc func(context.Context, *ent.SettingQuery) (ent.Value, error) @@ -616,6 +672,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil + case *ent.IdempotencyRecordQuery: + return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil case *ent.PromoCodeQuery: return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil case *ent.PromoCodeUsageQuery: @@ -624,6 +682,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.ProxyQuery, predicate.Proxy, proxy.OrderOption]{typ: ent.TypeProxy, tq: q}, nil case *ent.RedeemCodeQuery: return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil + case *ent.SecuritySecretQuery: + return &query[*ent.SecuritySecretQuery, predicate.SecuritySecret, securitysecret.OrderOption]{typ: ent.TypeSecuritySecret, tq: q}, nil case *ent.SettingQuery: return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil case *ent.UsageCleanupTaskQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index f9e90d73..769dddce 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -18,6 +18,7 @@ var ( {Name: "key", Type: field.TypeString, Unique: true, Size: 128}, {Name: "name", Type: field.TypeString, Size: 100}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, + {Name: "last_used_at", Type: field.TypeTime, Nullable: true}, {Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true}, {Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true}, {Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, @@ -34,13 +35,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "api_keys_groups_api_keys", - Columns: []*schema.Column{APIKeysColumns[12]}, + Columns: []*schema.Column{APIKeysColumns[13]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "api_keys_users_api_keys", - Columns: []*schema.Column{APIKeysColumns[13]}, + Columns: []*schema.Column{APIKeysColumns[14]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -49,12 +50,12 @@ var ( { Name: "apikey_user_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[13]}, + Columns: []*schema.Column{APIKeysColumns[14]}, }, { Name: "apikey_group_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[12]}, + Columns: []*schema.Column{APIKeysColumns[13]}, }, { Name: "apikey_status", @@ -66,15 +67,20 @@ var ( Unique: false, Columns: []*schema.Column{APIKeysColumns[3]}, }, + { + Name: "apikey_last_used_at", + Unique: false, + Columns: []*schema.Column{APIKeysColumns[7]}, + }, { Name: "apikey_quota_quota_used", Unique: false, - Columns: []*schema.Column{APIKeysColumns[9], APIKeysColumns[10]}, + Columns: []*schema.Column{APIKeysColumns[10], APIKeysColumns[11]}, }, { Name: "apikey_expires_at", Unique: false, - Columns: []*schema.Column{APIKeysColumns[11]}, + Columns: []*schema.Column{APIKeysColumns[12]}, }, }, } @@ -102,6 +108,8 @@ var ( {Name: "rate_limited_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "rate_limit_reset_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "overload_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "temp_unschedulable_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "temp_unschedulable_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "session_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_end", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_status", Type: field.TypeString, Nullable: true, Size: 20}, @@ -115,7 +123,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "accounts_proxies_proxy", - Columns: []*schema.Column{AccountsColumns[25]}, + Columns: []*schema.Column{AccountsColumns[27]}, RefColumns: []*schema.Column{ProxiesColumns[0]}, OnDelete: schema.SetNull, }, @@ -139,7 +147,7 @@ var ( { Name: "account_proxy_id", Unique: false, - Columns: []*schema.Column{AccountsColumns[25]}, + Columns: []*schema.Column{AccountsColumns[27]}, }, { Name: "account_priority", @@ -171,6 +179,16 @@ var ( Unique: false, Columns: []*schema.Column{AccountsColumns[21]}, }, + { + Name: "account_platform_priority", + Unique: false, + Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]}, + }, + { + Name: "account_priority_status", + Unique: false, + Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]}, + }, { Name: "account_deleted_at", Unique: false, @@ -325,6 +343,7 @@ var ( {Name: "response_code", Type: field.TypeInt, Nullable: true}, {Name: "passthrough_body", Type: field.TypeBool, Default: true}, {Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647}, + {Name: "skip_monitoring", Type: field.TypeBool, Default: false}, {Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647}, } // ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table. @@ -365,6 +384,11 @@ var ( {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, @@ -372,6 +396,7 @@ var ( {Name: "model_routing_enabled", Type: field.TypeBool, Default: false}, {Name: "mcp_xml_inject", Type: field.TypeBool, Default: true}, {Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "sort_order", Type: field.TypeInt, Default: 0}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ @@ -404,6 +429,49 @@ var ( Unique: false, Columns: []*schema.Column{GroupsColumns[3]}, }, + { + Name: "group_sort_order", + Unique: false, + Columns: []*schema.Column{GroupsColumns[30]}, + }, + }, + } + // IdempotencyRecordsColumns holds the columns for the "idempotency_records" table. + IdempotencyRecordsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "scope", Type: field.TypeString, Size: 128}, + {Name: "idempotency_key_hash", Type: field.TypeString, Size: 64}, + {Name: "request_fingerprint", Type: field.TypeString, Size: 64}, + {Name: "status", Type: field.TypeString, Size: 32}, + {Name: "response_status", Type: field.TypeInt, Nullable: true}, + {Name: "response_body", Type: field.TypeString, Nullable: true}, + {Name: "error_reason", Type: field.TypeString, Nullable: true, Size: 128}, + {Name: "locked_until", Type: field.TypeTime, Nullable: true}, + {Name: "expires_at", Type: field.TypeTime}, + } + // IdempotencyRecordsTable holds the schema information for the "idempotency_records" table. + IdempotencyRecordsTable = &schema.Table{ + Name: "idempotency_records", + Columns: IdempotencyRecordsColumns, + PrimaryKey: []*schema.Column{IdempotencyRecordsColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "idempotencyrecord_scope_idempotency_key_hash", + Unique: true, + Columns: []*schema.Column{IdempotencyRecordsColumns[3], IdempotencyRecordsColumns[4]}, + }, + { + Name: "idempotencyrecord_expires_at", + Unique: false, + Columns: []*schema.Column{IdempotencyRecordsColumns[11]}, + }, + { + Name: "idempotencyrecord_status_locked_until", + Unique: false, + Columns: []*schema.Column{IdempotencyRecordsColumns[6], IdempotencyRecordsColumns[10]}, + }, }, } // PromoCodesColumns holds the columns for the "promo_codes" table. @@ -565,6 +633,20 @@ var ( }, }, } + // SecuritySecretsColumns holds the columns for the "security_secrets" table. + SecuritySecretsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "key", Type: field.TypeString, Unique: true, Size: 100}, + {Name: "value", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + } + // SecuritySecretsTable holds the schema information for the "security_secrets" table. + SecuritySecretsTable = &schema.Table{ + Name: "security_secrets", + Columns: SecuritySecretsColumns, + PrimaryKey: []*schema.Column{SecuritySecretsColumns[0]}, + } // SettingsColumns holds the columns for the "settings" table. SettingsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -643,6 +725,8 @@ var ( {Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45}, {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, + {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16}, + {Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "api_key_id", Type: field.TypeInt64}, {Name: "account_id", Type: field.TypeInt64}, @@ -658,31 +742,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -691,32 +775,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, }, { Name: "usagelog_model", @@ -731,12 +815,17 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]}, + }, + { + Name: "usagelog_group_id_created_at", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]}, }, }, } @@ -757,6 +846,8 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, + {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, + {Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ @@ -962,6 +1053,11 @@ var ( Unique: false, Columns: []*schema.Column{UserSubscriptionsColumns[5]}, }, + { + Name: "usersubscription_user_id_status_expires_at", + Unique: false, + Columns: []*schema.Column{UserSubscriptionsColumns[16], UserSubscriptionsColumns[6], UserSubscriptionsColumns[5]}, + }, { Name: "usersubscription_assigned_by", Unique: false, @@ -988,10 +1084,12 @@ var ( AnnouncementReadsTable, ErrorPassthroughRulesTable, GroupsTable, + IdempotencyRecordsTable, PromoCodesTable, PromoCodeUsagesTable, ProxiesTable, RedeemCodesTable, + SecuritySecretsTable, SettingsTable, UsageCleanupTasksTable, UsageLogsTable, @@ -1032,6 +1130,9 @@ func init() { GroupsTable.Annotation = &entsql.Annotation{ Table: "groups", } + IdempotencyRecordsTable.Annotation = &entsql.Annotation{ + Table: "idempotency_records", + } PromoCodesTable.Annotation = &entsql.Annotation{ Table: "promo_codes", } @@ -1048,6 +1149,9 @@ func init() { RedeemCodesTable.Annotation = &entsql.Annotation{ Table: "redeem_codes", } + SecuritySecretsTable.Annotation = &entsql.Annotation{ + Table: "security_secrets", + } SettingsTable.Annotation = &entsql.Annotation{ Table: "settings", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 5c182dea..823cd389 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -19,11 +19,13 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -51,10 +53,12 @@ const ( TypeAnnouncementRead = "AnnouncementRead" TypeErrorPassthroughRule = "ErrorPassthroughRule" TypeGroup = "Group" + TypeIdempotencyRecord = "IdempotencyRecord" TypePromoCode = "PromoCode" TypePromoCodeUsage = "PromoCodeUsage" TypeProxy = "Proxy" TypeRedeemCode = "RedeemCode" + TypeSecuritySecret = "SecuritySecret" TypeSetting = "Setting" TypeUsageCleanupTask = "UsageCleanupTask" TypeUsageLog = "UsageLog" @@ -77,6 +81,7 @@ type APIKeyMutation struct { key *string name *string status *string + last_used_at *time.Time ip_whitelist *[]string appendip_whitelist []string ip_blacklist *[]string @@ -511,6 +516,55 @@ func (m *APIKeyMutation) ResetStatus() { m.status = nil } +// SetLastUsedAt sets the "last_used_at" field. +func (m *APIKeyMutation) SetLastUsedAt(t time.Time) { + m.last_used_at = &t +} + +// LastUsedAt returns the value of the "last_used_at" field in the mutation. +func (m *APIKeyMutation) LastUsedAt() (r time.Time, exists bool) { + v := m.last_used_at + if v == nil { + return + } + return *v, true +} + +// OldLastUsedAt returns the old "last_used_at" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldLastUsedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastUsedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastUsedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastUsedAt: %w", err) + } + return oldValue.LastUsedAt, nil +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (m *APIKeyMutation) ClearLastUsedAt() { + m.last_used_at = nil + m.clearedFields[apikey.FieldLastUsedAt] = struct{}{} +} + +// LastUsedAtCleared returns if the "last_used_at" field was cleared in this mutation. +func (m *APIKeyMutation) LastUsedAtCleared() bool { + _, ok := m.clearedFields[apikey.FieldLastUsedAt] + return ok +} + +// ResetLastUsedAt resets all changes to the "last_used_at" field. +func (m *APIKeyMutation) ResetLastUsedAt() { + m.last_used_at = nil + delete(m.clearedFields, apikey.FieldLastUsedAt) +} + // SetIPWhitelist sets the "ip_whitelist" field. func (m *APIKeyMutation) SetIPWhitelist(s []string) { m.ip_whitelist = &s @@ -944,7 +998,7 @@ func (m *APIKeyMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *APIKeyMutation) Fields() []string { - fields := make([]string, 0, 13) + fields := make([]string, 0, 14) if m.created_at != nil { fields = append(fields, apikey.FieldCreatedAt) } @@ -969,6 +1023,9 @@ func (m *APIKeyMutation) Fields() []string { if m.status != nil { fields = append(fields, apikey.FieldStatus) } + if m.last_used_at != nil { + fields = append(fields, apikey.FieldLastUsedAt) + } if m.ip_whitelist != nil { fields = append(fields, apikey.FieldIPWhitelist) } @@ -1008,6 +1065,8 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { return m.GroupID() case apikey.FieldStatus: return m.Status() + case apikey.FieldLastUsedAt: + return m.LastUsedAt() case apikey.FieldIPWhitelist: return m.IPWhitelist() case apikey.FieldIPBlacklist: @@ -1043,6 +1102,8 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldGroupID(ctx) case apikey.FieldStatus: return m.OldStatus(ctx) + case apikey.FieldLastUsedAt: + return m.OldLastUsedAt(ctx) case apikey.FieldIPWhitelist: return m.OldIPWhitelist(ctx) case apikey.FieldIPBlacklist: @@ -1118,6 +1179,13 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { } m.SetStatus(v) return nil + case apikey.FieldLastUsedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastUsedAt(v) + return nil case apikey.FieldIPWhitelist: v, ok := value.([]string) if !ok { @@ -1216,6 +1284,9 @@ func (m *APIKeyMutation) ClearedFields() []string { if m.FieldCleared(apikey.FieldGroupID) { fields = append(fields, apikey.FieldGroupID) } + if m.FieldCleared(apikey.FieldLastUsedAt) { + fields = append(fields, apikey.FieldLastUsedAt) + } if m.FieldCleared(apikey.FieldIPWhitelist) { fields = append(fields, apikey.FieldIPWhitelist) } @@ -1245,6 +1316,9 @@ func (m *APIKeyMutation) ClearField(name string) error { case apikey.FieldGroupID: m.ClearGroupID() return nil + case apikey.FieldLastUsedAt: + m.ClearLastUsedAt() + return nil case apikey.FieldIPWhitelist: m.ClearIPWhitelist() return nil @@ -1286,6 +1360,9 @@ func (m *APIKeyMutation) ResetField(name string) error { case apikey.FieldStatus: m.ResetStatus() return nil + case apikey.FieldLastUsedAt: + m.ResetLastUsedAt() + return nil case apikey.FieldIPWhitelist: m.ResetIPWhitelist() return nil @@ -1428,48 +1505,50 @@ func (m *APIKeyMutation) ResetEdge(name string) error { // AccountMutation represents an operation that mutates the Account nodes in the graph. type AccountMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - notes *string - platform *string - _type *string - credentials *map[string]interface{} - extra *map[string]interface{} - concurrency *int - addconcurrency *int - priority *int - addpriority *int - rate_multiplier *float64 - addrate_multiplier *float64 - status *string - error_message *string - last_used_at *time.Time - expires_at *time.Time - auto_pause_on_expired *bool - schedulable *bool - rate_limited_at *time.Time - rate_limit_reset_at *time.Time - overload_until *time.Time - session_window_start *time.Time - session_window_end *time.Time - session_window_status *string - clearedFields map[string]struct{} - groups map[int64]struct{} - removedgroups map[int64]struct{} - clearedgroups bool - proxy *int64 - clearedproxy bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - done bool - oldValue func(context.Context) (*Account, error) - predicates []predicate.Account + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + notes *string + platform *string + _type *string + credentials *map[string]interface{} + extra *map[string]interface{} + concurrency *int + addconcurrency *int + priority *int + addpriority *int + rate_multiplier *float64 + addrate_multiplier *float64 + status *string + error_message *string + last_used_at *time.Time + expires_at *time.Time + auto_pause_on_expired *bool + schedulable *bool + rate_limited_at *time.Time + rate_limit_reset_at *time.Time + overload_until *time.Time + temp_unschedulable_until *time.Time + temp_unschedulable_reason *string + session_window_start *time.Time + session_window_end *time.Time + session_window_status *string + clearedFields map[string]struct{} + groups map[int64]struct{} + removedgroups map[int64]struct{} + clearedgroups bool + proxy *int64 + clearedproxy bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + done bool + oldValue func(context.Context) (*Account, error) + predicates []predicate.Account } var _ ent.Mutation = (*AccountMutation)(nil) @@ -2539,6 +2618,104 @@ func (m *AccountMutation) ResetOverloadUntil() { delete(m.clearedFields, account.FieldOverloadUntil) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (m *AccountMutation) SetTempUnschedulableUntil(t time.Time) { + m.temp_unschedulable_until = &t +} + +// TempUnschedulableUntil returns the value of the "temp_unschedulable_until" field in the mutation. +func (m *AccountMutation) TempUnschedulableUntil() (r time.Time, exists bool) { + v := m.temp_unschedulable_until + if v == nil { + return + } + return *v, true +} + +// OldTempUnschedulableUntil returns the old "temp_unschedulable_until" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AccountMutation) OldTempUnschedulableUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTempUnschedulableUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTempUnschedulableUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTempUnschedulableUntil: %w", err) + } + return oldValue.TempUnschedulableUntil, nil +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (m *AccountMutation) ClearTempUnschedulableUntil() { + m.temp_unschedulable_until = nil + m.clearedFields[account.FieldTempUnschedulableUntil] = struct{}{} +} + +// TempUnschedulableUntilCleared returns if the "temp_unschedulable_until" field was cleared in this mutation. +func (m *AccountMutation) TempUnschedulableUntilCleared() bool { + _, ok := m.clearedFields[account.FieldTempUnschedulableUntil] + return ok +} + +// ResetTempUnschedulableUntil resets all changes to the "temp_unschedulable_until" field. +func (m *AccountMutation) ResetTempUnschedulableUntil() { + m.temp_unschedulable_until = nil + delete(m.clearedFields, account.FieldTempUnschedulableUntil) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (m *AccountMutation) SetTempUnschedulableReason(s string) { + m.temp_unschedulable_reason = &s +} + +// TempUnschedulableReason returns the value of the "temp_unschedulable_reason" field in the mutation. +func (m *AccountMutation) TempUnschedulableReason() (r string, exists bool) { + v := m.temp_unschedulable_reason + if v == nil { + return + } + return *v, true +} + +// OldTempUnschedulableReason returns the old "temp_unschedulable_reason" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AccountMutation) OldTempUnschedulableReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTempUnschedulableReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTempUnschedulableReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTempUnschedulableReason: %w", err) + } + return oldValue.TempUnschedulableReason, nil +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (m *AccountMutation) ClearTempUnschedulableReason() { + m.temp_unschedulable_reason = nil + m.clearedFields[account.FieldTempUnschedulableReason] = struct{}{} +} + +// TempUnschedulableReasonCleared returns if the "temp_unschedulable_reason" field was cleared in this mutation. +func (m *AccountMutation) TempUnschedulableReasonCleared() bool { + _, ok := m.clearedFields[account.FieldTempUnschedulableReason] + return ok +} + +// ResetTempUnschedulableReason resets all changes to the "temp_unschedulable_reason" field. +func (m *AccountMutation) ResetTempUnschedulableReason() { + m.temp_unschedulable_reason = nil + delete(m.clearedFields, account.FieldTempUnschedulableReason) +} + // SetSessionWindowStart sets the "session_window_start" field. func (m *AccountMutation) SetSessionWindowStart(t time.Time) { m.session_window_start = &t @@ -2855,7 +3032,7 @@ func (m *AccountMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AccountMutation) Fields() []string { - fields := make([]string, 0, 25) + fields := make([]string, 0, 27) if m.created_at != nil { fields = append(fields, account.FieldCreatedAt) } @@ -2922,6 +3099,12 @@ func (m *AccountMutation) Fields() []string { if m.overload_until != nil { fields = append(fields, account.FieldOverloadUntil) } + if m.temp_unschedulable_until != nil { + fields = append(fields, account.FieldTempUnschedulableUntil) + } + if m.temp_unschedulable_reason != nil { + fields = append(fields, account.FieldTempUnschedulableReason) + } if m.session_window_start != nil { fields = append(fields, account.FieldSessionWindowStart) } @@ -2983,6 +3166,10 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) { return m.RateLimitResetAt() case account.FieldOverloadUntil: return m.OverloadUntil() + case account.FieldTempUnschedulableUntil: + return m.TempUnschedulableUntil() + case account.FieldTempUnschedulableReason: + return m.TempUnschedulableReason() case account.FieldSessionWindowStart: return m.SessionWindowStart() case account.FieldSessionWindowEnd: @@ -3042,6 +3229,10 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldRateLimitResetAt(ctx) case account.FieldOverloadUntil: return m.OldOverloadUntil(ctx) + case account.FieldTempUnschedulableUntil: + return m.OldTempUnschedulableUntil(ctx) + case account.FieldTempUnschedulableReason: + return m.OldTempUnschedulableReason(ctx) case account.FieldSessionWindowStart: return m.OldSessionWindowStart(ctx) case account.FieldSessionWindowEnd: @@ -3211,6 +3402,20 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error { } m.SetOverloadUntil(v) return nil + case account.FieldTempUnschedulableUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTempUnschedulableUntil(v) + return nil + case account.FieldTempUnschedulableReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTempUnschedulableReason(v) + return nil case account.FieldSessionWindowStart: v, ok := value.(time.Time) if !ok { @@ -3328,6 +3533,12 @@ func (m *AccountMutation) ClearedFields() []string { if m.FieldCleared(account.FieldOverloadUntil) { fields = append(fields, account.FieldOverloadUntil) } + if m.FieldCleared(account.FieldTempUnschedulableUntil) { + fields = append(fields, account.FieldTempUnschedulableUntil) + } + if m.FieldCleared(account.FieldTempUnschedulableReason) { + fields = append(fields, account.FieldTempUnschedulableReason) + } if m.FieldCleared(account.FieldSessionWindowStart) { fields = append(fields, account.FieldSessionWindowStart) } @@ -3378,6 +3589,12 @@ func (m *AccountMutation) ClearField(name string) error { case account.FieldOverloadUntil: m.ClearOverloadUntil() return nil + case account.FieldTempUnschedulableUntil: + m.ClearTempUnschedulableUntil() + return nil + case account.FieldTempUnschedulableReason: + m.ClearTempUnschedulableReason() + return nil case account.FieldSessionWindowStart: m.ClearSessionWindowStart() return nil @@ -3461,6 +3678,12 @@ func (m *AccountMutation) ResetField(name string) error { case account.FieldOverloadUntil: m.ResetOverloadUntil() return nil + case account.FieldTempUnschedulableUntil: + m.ResetTempUnschedulableUntil() + return nil + case account.FieldTempUnschedulableReason: + m.ResetTempUnschedulableReason() + return nil case account.FieldSessionWindowStart: m.ResetSessionWindowStart() return nil @@ -5776,6 +5999,7 @@ type ErrorPassthroughRuleMutation struct { addresponse_code *int passthrough_body *bool custom_message *string + skip_monitoring *bool description *string clearedFields map[string]struct{} done bool @@ -6503,6 +6727,42 @@ func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() { delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) { + m.skip_monitoring = &b +} + +// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation. +func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) { + v := m.skip_monitoring + if v == nil { + return + } + return *v, true +} + +// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSkipMonitoring requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err) + } + return oldValue.SkipMonitoring, nil +} + +// ResetSkipMonitoring resets all changes to the "skip_monitoring" field. +func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() { + m.skip_monitoring = nil +} + // SetDescription sets the "description" field. func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { m.description = &s @@ -6586,7 +6846,7 @@ func (m *ErrorPassthroughRuleMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *ErrorPassthroughRuleMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 15) if m.created_at != nil { fields = append(fields, errorpassthroughrule.FieldCreatedAt) } @@ -6626,6 +6886,9 @@ func (m *ErrorPassthroughRuleMutation) Fields() []string { if m.custom_message != nil { fields = append(fields, errorpassthroughrule.FieldCustomMessage) } + if m.skip_monitoring != nil { + fields = append(fields, errorpassthroughrule.FieldSkipMonitoring) + } if m.description != nil { fields = append(fields, errorpassthroughrule.FieldDescription) } @@ -6663,6 +6926,8 @@ func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { return m.PassthroughBody() case errorpassthroughrule.FieldCustomMessage: return m.CustomMessage() + case errorpassthroughrule.FieldSkipMonitoring: + return m.SkipMonitoring() case errorpassthroughrule.FieldDescription: return m.Description() } @@ -6700,6 +6965,8 @@ func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string return m.OldPassthroughBody(ctx) case errorpassthroughrule.FieldCustomMessage: return m.OldCustomMessage(ctx) + case errorpassthroughrule.FieldSkipMonitoring: + return m.OldSkipMonitoring(ctx) case errorpassthroughrule.FieldDescription: return m.OldDescription(ctx) } @@ -6802,6 +7069,13 @@ func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) er } m.SetCustomMessage(v) return nil + case errorpassthroughrule.FieldSkipMonitoring: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSkipMonitoring(v) + return nil case errorpassthroughrule.FieldDescription: v, ok := value.(string) if !ok { @@ -6963,6 +7237,9 @@ func (m *ErrorPassthroughRuleMutation) ResetField(name string) error { case errorpassthroughrule.FieldCustomMessage: m.ResetCustomMessage() return nil + case errorpassthroughrule.FieldSkipMonitoring: + m.ResetSkipMonitoring() + return nil case errorpassthroughrule.FieldDescription: m.ResetDescription() return nil @@ -7049,6 +7326,16 @@ type GroupMutation struct { addimage_price_2k *float64 image_price_4k *float64 addimage_price_4k *float64 + sora_image_price_360 *float64 + addsora_image_price_360 *float64 + sora_image_price_540 *float64 + addsora_image_price_540 *float64 + sora_video_price_per_request *float64 + addsora_video_price_per_request *float64 + sora_video_price_per_request_hd *float64 + addsora_video_price_per_request_hd *float64 + sora_storage_quota_bytes *int64 + addsora_storage_quota_bytes *int64 claude_code_only *bool fallback_group_id *int64 addfallback_group_id *int64 @@ -7059,6 +7346,8 @@ type GroupMutation struct { mcp_xml_inject *bool supported_model_scopes *[]string appendsupported_model_scopes []string + sort_order *int + addsort_order *int clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -8063,6 +8352,342 @@ func (m *GroupMutation) ResetImagePrice4k() { delete(m.clearedFields, group.FieldImagePrice4k) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (m *GroupMutation) SetSoraImagePrice360(f float64) { + m.sora_image_price_360 = &f + m.addsora_image_price_360 = nil +} + +// SoraImagePrice360 returns the value of the "sora_image_price_360" field in the mutation. +func (m *GroupMutation) SoraImagePrice360() (r float64, exists bool) { + v := m.sora_image_price_360 + if v == nil { + return + } + return *v, true +} + +// OldSoraImagePrice360 returns the old "sora_image_price_360" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraImagePrice360(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraImagePrice360 is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraImagePrice360 requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraImagePrice360: %w", err) + } + return oldValue.SoraImagePrice360, nil +} + +// AddSoraImagePrice360 adds f to the "sora_image_price_360" field. +func (m *GroupMutation) AddSoraImagePrice360(f float64) { + if m.addsora_image_price_360 != nil { + *m.addsora_image_price_360 += f + } else { + m.addsora_image_price_360 = &f + } +} + +// AddedSoraImagePrice360 returns the value that was added to the "sora_image_price_360" field in this mutation. +func (m *GroupMutation) AddedSoraImagePrice360() (r float64, exists bool) { + v := m.addsora_image_price_360 + if v == nil { + return + } + return *v, true +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (m *GroupMutation) ClearSoraImagePrice360() { + m.sora_image_price_360 = nil + m.addsora_image_price_360 = nil + m.clearedFields[group.FieldSoraImagePrice360] = struct{}{} +} + +// SoraImagePrice360Cleared returns if the "sora_image_price_360" field was cleared in this mutation. +func (m *GroupMutation) SoraImagePrice360Cleared() bool { + _, ok := m.clearedFields[group.FieldSoraImagePrice360] + return ok +} + +// ResetSoraImagePrice360 resets all changes to the "sora_image_price_360" field. +func (m *GroupMutation) ResetSoraImagePrice360() { + m.sora_image_price_360 = nil + m.addsora_image_price_360 = nil + delete(m.clearedFields, group.FieldSoraImagePrice360) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (m *GroupMutation) SetSoraImagePrice540(f float64) { + m.sora_image_price_540 = &f + m.addsora_image_price_540 = nil +} + +// SoraImagePrice540 returns the value of the "sora_image_price_540" field in the mutation. +func (m *GroupMutation) SoraImagePrice540() (r float64, exists bool) { + v := m.sora_image_price_540 + if v == nil { + return + } + return *v, true +} + +// OldSoraImagePrice540 returns the old "sora_image_price_540" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraImagePrice540(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraImagePrice540 is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraImagePrice540 requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraImagePrice540: %w", err) + } + return oldValue.SoraImagePrice540, nil +} + +// AddSoraImagePrice540 adds f to the "sora_image_price_540" field. +func (m *GroupMutation) AddSoraImagePrice540(f float64) { + if m.addsora_image_price_540 != nil { + *m.addsora_image_price_540 += f + } else { + m.addsora_image_price_540 = &f + } +} + +// AddedSoraImagePrice540 returns the value that was added to the "sora_image_price_540" field in this mutation. +func (m *GroupMutation) AddedSoraImagePrice540() (r float64, exists bool) { + v := m.addsora_image_price_540 + if v == nil { + return + } + return *v, true +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (m *GroupMutation) ClearSoraImagePrice540() { + m.sora_image_price_540 = nil + m.addsora_image_price_540 = nil + m.clearedFields[group.FieldSoraImagePrice540] = struct{}{} +} + +// SoraImagePrice540Cleared returns if the "sora_image_price_540" field was cleared in this mutation. +func (m *GroupMutation) SoraImagePrice540Cleared() bool { + _, ok := m.clearedFields[group.FieldSoraImagePrice540] + return ok +} + +// ResetSoraImagePrice540 resets all changes to the "sora_image_price_540" field. +func (m *GroupMutation) ResetSoraImagePrice540() { + m.sora_image_price_540 = nil + m.addsora_image_price_540 = nil + delete(m.clearedFields, group.FieldSoraImagePrice540) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (m *GroupMutation) SetSoraVideoPricePerRequest(f float64) { + m.sora_video_price_per_request = &f + m.addsora_video_price_per_request = nil +} + +// SoraVideoPricePerRequest returns the value of the "sora_video_price_per_request" field in the mutation. +func (m *GroupMutation) SoraVideoPricePerRequest() (r float64, exists bool) { + v := m.sora_video_price_per_request + if v == nil { + return + } + return *v, true +} + +// OldSoraVideoPricePerRequest returns the old "sora_video_price_per_request" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraVideoPricePerRequest(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraVideoPricePerRequest is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraVideoPricePerRequest requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequest: %w", err) + } + return oldValue.SoraVideoPricePerRequest, nil +} + +// AddSoraVideoPricePerRequest adds f to the "sora_video_price_per_request" field. +func (m *GroupMutation) AddSoraVideoPricePerRequest(f float64) { + if m.addsora_video_price_per_request != nil { + *m.addsora_video_price_per_request += f + } else { + m.addsora_video_price_per_request = &f + } +} + +// AddedSoraVideoPricePerRequest returns the value that was added to the "sora_video_price_per_request" field in this mutation. +func (m *GroupMutation) AddedSoraVideoPricePerRequest() (r float64, exists bool) { + v := m.addsora_video_price_per_request + if v == nil { + return + } + return *v, true +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (m *GroupMutation) ClearSoraVideoPricePerRequest() { + m.sora_video_price_per_request = nil + m.addsora_video_price_per_request = nil + m.clearedFields[group.FieldSoraVideoPricePerRequest] = struct{}{} +} + +// SoraVideoPricePerRequestCleared returns if the "sora_video_price_per_request" field was cleared in this mutation. +func (m *GroupMutation) SoraVideoPricePerRequestCleared() bool { + _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequest] + return ok +} + +// ResetSoraVideoPricePerRequest resets all changes to the "sora_video_price_per_request" field. +func (m *GroupMutation) ResetSoraVideoPricePerRequest() { + m.sora_video_price_per_request = nil + m.addsora_video_price_per_request = nil + delete(m.clearedFields, group.FieldSoraVideoPricePerRequest) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) SetSoraVideoPricePerRequestHd(f float64) { + m.sora_video_price_per_request_hd = &f + m.addsora_video_price_per_request_hd = nil +} + +// SoraVideoPricePerRequestHd returns the value of the "sora_video_price_per_request_hd" field in the mutation. +func (m *GroupMutation) SoraVideoPricePerRequestHd() (r float64, exists bool) { + v := m.sora_video_price_per_request_hd + if v == nil { + return + } + return *v, true +} + +// OldSoraVideoPricePerRequestHd returns the old "sora_video_price_per_request_hd" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraVideoPricePerRequestHd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraVideoPricePerRequestHd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraVideoPricePerRequestHd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequestHd: %w", err) + } + return oldValue.SoraVideoPricePerRequestHd, nil +} + +// AddSoraVideoPricePerRequestHd adds f to the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) AddSoraVideoPricePerRequestHd(f float64) { + if m.addsora_video_price_per_request_hd != nil { + *m.addsora_video_price_per_request_hd += f + } else { + m.addsora_video_price_per_request_hd = &f + } +} + +// AddedSoraVideoPricePerRequestHd returns the value that was added to the "sora_video_price_per_request_hd" field in this mutation. +func (m *GroupMutation) AddedSoraVideoPricePerRequestHd() (r float64, exists bool) { + v := m.addsora_video_price_per_request_hd + if v == nil { + return + } + return *v, true +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) ClearSoraVideoPricePerRequestHd() { + m.sora_video_price_per_request_hd = nil + m.addsora_video_price_per_request_hd = nil + m.clearedFields[group.FieldSoraVideoPricePerRequestHd] = struct{}{} +} + +// SoraVideoPricePerRequestHdCleared returns if the "sora_video_price_per_request_hd" field was cleared in this mutation. +func (m *GroupMutation) SoraVideoPricePerRequestHdCleared() bool { + _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequestHd] + return ok +} + +// ResetSoraVideoPricePerRequestHd resets all changes to the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() { + m.sora_video_price_per_request_hd = nil + m.addsora_video_price_per_request_hd = nil + delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd) +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (m *GroupMutation) SetSoraStorageQuotaBytes(i int64) { + m.sora_storage_quota_bytes = &i + m.addsora_storage_quota_bytes = nil +} + +// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. +func (m *GroupMutation) SoraStorageQuotaBytes() (r int64, exists bool) { + v := m.sora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) + } + return oldValue.SoraStorageQuotaBytes, nil +} + +// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. +func (m *GroupMutation) AddSoraStorageQuotaBytes(i int64) { + if m.addsora_storage_quota_bytes != nil { + *m.addsora_storage_quota_bytes += i + } else { + m.addsora_storage_quota_bytes = &i + } +} + +// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. +func (m *GroupMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { + v := m.addsora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. +func (m *GroupMutation) ResetSoraStorageQuotaBytes() { + m.sora_storage_quota_bytes = nil + m.addsora_storage_quota_bytes = nil +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (m *GroupMutation) SetClaudeCodeOnly(b bool) { m.claude_code_only = &b @@ -8411,6 +9036,62 @@ func (m *GroupMutation) ResetSupportedModelScopes() { m.appendsupported_model_scopes = nil } +// SetSortOrder sets the "sort_order" field. +func (m *GroupMutation) SetSortOrder(i int) { + m.sort_order = &i + m.addsort_order = nil +} + +// SortOrder returns the value of the "sort_order" field in the mutation. +func (m *GroupMutation) SortOrder() (r int, exists bool) { + v := m.sort_order + if v == nil { + return + } + return *v, true +} + +// OldSortOrder returns the old "sort_order" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSortOrder requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) + } + return oldValue.SortOrder, nil +} + +// AddSortOrder adds i to the "sort_order" field. +func (m *GroupMutation) AddSortOrder(i int) { + if m.addsort_order != nil { + *m.addsort_order += i + } else { + m.addsort_order = &i + } +} + +// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. +func (m *GroupMutation) AddedSortOrder() (r int, exists bool) { + v := m.addsort_order + if v == nil { + return + } + return *v, true +} + +// ResetSortOrder resets all changes to the "sort_order" field. +func (m *GroupMutation) ResetSortOrder() { + m.sort_order = nil + m.addsort_order = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -8769,7 +9450,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 24) + fields := make([]string, 0, 30) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -8821,6 +9502,21 @@ func (m *GroupMutation) Fields() []string { if m.image_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.sora_image_price_360 != nil { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.sora_image_price_540 != nil { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.sora_video_price_per_request != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.sora_video_price_per_request_hd != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } + if m.sora_storage_quota_bytes != nil { + fields = append(fields, group.FieldSoraStorageQuotaBytes) + } if m.claude_code_only != nil { fields = append(fields, group.FieldClaudeCodeOnly) } @@ -8842,6 +9538,9 @@ func (m *GroupMutation) Fields() []string { if m.supported_model_scopes != nil { fields = append(fields, group.FieldSupportedModelScopes) } + if m.sort_order != nil { + fields = append(fields, group.FieldSortOrder) + } return fields } @@ -8884,6 +9583,16 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ImagePrice2k() case group.FieldImagePrice4k: return m.ImagePrice4k() + case group.FieldSoraImagePrice360: + return m.SoraImagePrice360() + case group.FieldSoraImagePrice540: + return m.SoraImagePrice540() + case group.FieldSoraVideoPricePerRequest: + return m.SoraVideoPricePerRequest() + case group.FieldSoraVideoPricePerRequestHd: + return m.SoraVideoPricePerRequestHd() + case group.FieldSoraStorageQuotaBytes: + return m.SoraStorageQuotaBytes() case group.FieldClaudeCodeOnly: return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: @@ -8898,6 +9607,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.McpXMLInject() case group.FieldSupportedModelScopes: return m.SupportedModelScopes() + case group.FieldSortOrder: + return m.SortOrder() } return nil, false } @@ -8941,6 +9652,16 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldImagePrice2k(ctx) case group.FieldImagePrice4k: return m.OldImagePrice4k(ctx) + case group.FieldSoraImagePrice360: + return m.OldSoraImagePrice360(ctx) + case group.FieldSoraImagePrice540: + return m.OldSoraImagePrice540(ctx) + case group.FieldSoraVideoPricePerRequest: + return m.OldSoraVideoPricePerRequest(ctx) + case group.FieldSoraVideoPricePerRequestHd: + return m.OldSoraVideoPricePerRequestHd(ctx) + case group.FieldSoraStorageQuotaBytes: + return m.OldSoraStorageQuotaBytes(ctx) case group.FieldClaudeCodeOnly: return m.OldClaudeCodeOnly(ctx) case group.FieldFallbackGroupID: @@ -8955,6 +9676,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldMcpXMLInject(ctx) case group.FieldSupportedModelScopes: return m.OldSupportedModelScopes(ctx) + case group.FieldSortOrder: + return m.OldSortOrder(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -9083,6 +9806,41 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetImagePrice4k(v) return nil + case group.FieldSoraImagePrice360: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraImagePrice360(v) + return nil + case group.FieldSoraImagePrice540: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraImagePrice540(v) + return nil + case group.FieldSoraVideoPricePerRequest: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraVideoPricePerRequest(v) + return nil + case group.FieldSoraVideoPricePerRequestHd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraVideoPricePerRequestHd(v) + return nil + case group.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraStorageQuotaBytes(v) + return nil case group.FieldClaudeCodeOnly: v, ok := value.(bool) if !ok { @@ -9132,6 +9890,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetSupportedModelScopes(v) return nil + case group.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSortOrder(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -9164,12 +9929,30 @@ func (m *GroupMutation) AddedFields() []string { if m.addimage_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.addsora_image_price_360 != nil { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.addsora_image_price_540 != nil { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.addsora_video_price_per_request != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.addsora_video_price_per_request_hd != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } + if m.addsora_storage_quota_bytes != nil { + fields = append(fields, group.FieldSoraStorageQuotaBytes) + } if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } if m.addfallback_group_id_on_invalid_request != nil { fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) } + if m.addsort_order != nil { + fields = append(fields, group.FieldSortOrder) + } return fields } @@ -9194,10 +9977,22 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice2k() case group.FieldImagePrice4k: return m.AddedImagePrice4k() + case group.FieldSoraImagePrice360: + return m.AddedSoraImagePrice360() + case group.FieldSoraImagePrice540: + return m.AddedSoraImagePrice540() + case group.FieldSoraVideoPricePerRequest: + return m.AddedSoraVideoPricePerRequest() + case group.FieldSoraVideoPricePerRequestHd: + return m.AddedSoraVideoPricePerRequestHd() + case group.FieldSoraStorageQuotaBytes: + return m.AddedSoraStorageQuotaBytes() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() case group.FieldFallbackGroupIDOnInvalidRequest: return m.AddedFallbackGroupIDOnInvalidRequest() + case group.FieldSortOrder: + return m.AddedSortOrder() } return nil, false } @@ -9263,6 +10058,41 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddImagePrice4k(v) return nil + case group.FieldSoraImagePrice360: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraImagePrice360(v) + return nil + case group.FieldSoraImagePrice540: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraImagePrice540(v) + return nil + case group.FieldSoraVideoPricePerRequest: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraVideoPricePerRequest(v) + return nil + case group.FieldSoraVideoPricePerRequestHd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraVideoPricePerRequestHd(v) + return nil + case group.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageQuotaBytes(v) + return nil case group.FieldFallbackGroupID: v, ok := value.(int64) if !ok { @@ -9277,6 +10107,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddFallbackGroupIDOnInvalidRequest(v) return nil + case group.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSortOrder(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -9309,6 +10146,18 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldImagePrice4k) { fields = append(fields, group.FieldImagePrice4k) } + if m.FieldCleared(group.FieldSoraImagePrice360) { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.FieldCleared(group.FieldSoraImagePrice540) { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.FieldCleared(group.FieldSoraVideoPricePerRequest) { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.FieldCleared(group.FieldSoraVideoPricePerRequestHd) { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) } @@ -9356,6 +10205,18 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldImagePrice4k: m.ClearImagePrice4k() return nil + case group.FieldSoraImagePrice360: + m.ClearSoraImagePrice360() + return nil + case group.FieldSoraImagePrice540: + m.ClearSoraImagePrice540() + return nil + case group.FieldSoraVideoPricePerRequest: + m.ClearSoraVideoPricePerRequest() + return nil + case group.FieldSoraVideoPricePerRequestHd: + m.ClearSoraVideoPricePerRequestHd() + return nil case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil @@ -9424,6 +10285,21 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldImagePrice4k: m.ResetImagePrice4k() return nil + case group.FieldSoraImagePrice360: + m.ResetSoraImagePrice360() + return nil + case group.FieldSoraImagePrice540: + m.ResetSoraImagePrice540() + return nil + case group.FieldSoraVideoPricePerRequest: + m.ResetSoraVideoPricePerRequest() + return nil + case group.FieldSoraVideoPricePerRequestHd: + m.ResetSoraVideoPricePerRequestHd() + return nil + case group.FieldSoraStorageQuotaBytes: + m.ResetSoraStorageQuotaBytes() + return nil case group.FieldClaudeCodeOnly: m.ResetClaudeCodeOnly() return nil @@ -9445,6 +10321,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldSupportedModelScopes: m.ResetSupportedModelScopes() return nil + case group.FieldSortOrder: + m.ResetSortOrder() + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -9663,6 +10542,988 @@ func (m *GroupMutation) ResetEdge(name string) error { return fmt.Errorf("unknown Group edge %s", name) } +// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph. +type IdempotencyRecordMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + scope *string + idempotency_key_hash *string + request_fingerprint *string + status *string + response_status *int + addresponse_status *int + response_body *string + error_reason *string + locked_until *time.Time + expires_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*IdempotencyRecord, error) + predicates []predicate.IdempotencyRecord +} + +var _ ent.Mutation = (*IdempotencyRecordMutation)(nil) + +// idempotencyrecordOption allows management of the mutation configuration using functional options. +type idempotencyrecordOption func(*IdempotencyRecordMutation) + +// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity. +func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation { + m := &IdempotencyRecordMutation{ + config: c, + op: op, + typ: TypeIdempotencyRecord, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withIdempotencyRecordID sets the ID field of the mutation. +func withIdempotencyRecordID(id int64) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + var ( + err error + once sync.Once + value *IdempotencyRecord + ) + m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().IdempotencyRecord.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withIdempotencyRecord sets the old IdempotencyRecord of the mutation. +func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + m.oldValue = func(context.Context) (*IdempotencyRecord, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m IdempotencyRecordMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m IdempotencyRecordMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdempotencyRecordMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *IdempotencyRecordMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetScope sets the "scope" field. +func (m *IdempotencyRecordMutation) SetScope(s string) { + m.scope = &s +} + +// Scope returns the value of the "scope" field in the mutation. +func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) { + v := m.scope + if v == nil { + return + } + return *v, true +} + +// OldScope returns the old "scope" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldScope is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldScope requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldScope: %w", err) + } + return oldValue.Scope, nil +} + +// ResetScope resets all changes to the "scope" field. +func (m *IdempotencyRecordMutation) ResetScope() { + m.scope = nil +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) { + m.idempotency_key_hash = &s +} + +// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation. +func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) { + v := m.idempotency_key_hash + if v == nil { + return + } + return *v, true +} + +// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err) + } + return oldValue.IdempotencyKeyHash, nil +} + +// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field. +func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() { + m.idempotency_key_hash = nil +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) { + m.request_fingerprint = &s +} + +// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation. +func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) { + v := m.request_fingerprint + if v == nil { + return + } + return *v, true +} + +// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequestFingerprint requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err) + } + return oldValue.RequestFingerprint, nil +} + +// ResetRequestFingerprint resets all changes to the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) ResetRequestFingerprint() { + m.request_fingerprint = nil +} + +// SetStatus sets the "status" field. +func (m *IdempotencyRecordMutation) SetStatus(s string) { + m.status = &s +} + +// Status returns the value of the "status" field in the mutation. +func (m *IdempotencyRecordMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *IdempotencyRecordMutation) ResetStatus() { + m.status = nil +} + +// SetResponseStatus sets the "response_status" field. +func (m *IdempotencyRecordMutation) SetResponseStatus(i int) { + m.response_status = &i + m.addresponse_status = nil +} + +// ResponseStatus returns the value of the "response_status" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) { + v := m.response_status + if v == nil { + return + } + return *v, true +} + +// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err) + } + return oldValue.ResponseStatus, nil +} + +// AddResponseStatus adds i to the "response_status" field. +func (m *IdempotencyRecordMutation) AddResponseStatus(i int) { + if m.addresponse_status != nil { + *m.addresponse_status += i + } else { + m.addresponse_status = &i + } +} + +// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation. +func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) { + v := m.addresponse_status + if v == nil { + return + } + return *v, true +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (m *IdempotencyRecordMutation) ClearResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{} +} + +// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus] + return ok +} + +// ResetResponseStatus resets all changes to the "response_status" field. +func (m *IdempotencyRecordMutation) ResetResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseStatus) +} + +// SetResponseBody sets the "response_body" field. +func (m *IdempotencyRecordMutation) SetResponseBody(s string) { + m.response_body = &s +} + +// ResponseBody returns the value of the "response_body" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) { + v := m.response_body + if v == nil { + return + } + return *v, true +} + +// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseBody is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseBody requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseBody: %w", err) + } + return oldValue.ResponseBody, nil +} + +// ClearResponseBody clears the value of the "response_body" field. +func (m *IdempotencyRecordMutation) ClearResponseBody() { + m.response_body = nil + m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{} +} + +// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody] + return ok +} + +// ResetResponseBody resets all changes to the "response_body" field. +func (m *IdempotencyRecordMutation) ResetResponseBody() { + m.response_body = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseBody) +} + +// SetErrorReason sets the "error_reason" field. +func (m *IdempotencyRecordMutation) SetErrorReason(s string) { + m.error_reason = &s +} + +// ErrorReason returns the value of the "error_reason" field in the mutation. +func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) { + v := m.error_reason + if v == nil { + return + } + return *v, true +} + +// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldErrorReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldErrorReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldErrorReason: %w", err) + } + return oldValue.ErrorReason, nil +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (m *IdempotencyRecordMutation) ClearErrorReason() { + m.error_reason = nil + m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{} +} + +// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason] + return ok +} + +// ResetErrorReason resets all changes to the "error_reason" field. +func (m *IdempotencyRecordMutation) ResetErrorReason() { + m.error_reason = nil + delete(m.clearedFields, idempotencyrecord.FieldErrorReason) +} + +// SetLockedUntil sets the "locked_until" field. +func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) { + m.locked_until = &t +} + +// LockedUntil returns the value of the "locked_until" field in the mutation. +func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) { + v := m.locked_until + if v == nil { + return + } + return *v, true +} + +// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLockedUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err) + } + return oldValue.LockedUntil, nil +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (m *IdempotencyRecordMutation) ClearLockedUntil() { + m.locked_until = nil + m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{} +} + +// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) LockedUntilCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil] + return ok +} + +// ResetLockedUntil resets all changes to the "locked_until" field. +func (m *IdempotencyRecordMutation) ResetLockedUntil() { + m.locked_until = nil + delete(m.clearedFields, idempotencyrecord.FieldLockedUntil) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *IdempotencyRecordMutation) ResetExpiresAt() { + m.expires_at = nil +} + +// Where appends a list predicates to the IdempotencyRecordMutation builder. +func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdempotencyRecord, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *IdempotencyRecordMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *IdempotencyRecordMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (IdempotencyRecord). +func (m *IdempotencyRecordMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *IdempotencyRecordMutation) Fields() []string { + fields := make([]string, 0, 11) + if m.created_at != nil { + fields = append(fields, idempotencyrecord.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, idempotencyrecord.FieldUpdatedAt) + } + if m.scope != nil { + fields = append(fields, idempotencyrecord.FieldScope) + } + if m.idempotency_key_hash != nil { + fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash) + } + if m.request_fingerprint != nil { + fields = append(fields, idempotencyrecord.FieldRequestFingerprint) + } + if m.status != nil { + fields = append(fields, idempotencyrecord.FieldStatus) + } + if m.response_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + if m.response_body != nil { + fields = append(fields, idempotencyrecord.FieldResponseBody) + } + if m.error_reason != nil { + fields = append(fields, idempotencyrecord.FieldErrorReason) + } + if m.locked_until != nil { + fields = append(fields, idempotencyrecord.FieldLockedUntil) + } + if m.expires_at != nil { + fields = append(fields, idempotencyrecord.FieldExpiresAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.CreatedAt() + case idempotencyrecord.FieldUpdatedAt: + return m.UpdatedAt() + case idempotencyrecord.FieldScope: + return m.Scope() + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.IdempotencyKeyHash() + case idempotencyrecord.FieldRequestFingerprint: + return m.RequestFingerprint() + case idempotencyrecord.FieldStatus: + return m.Status() + case idempotencyrecord.FieldResponseStatus: + return m.ResponseStatus() + case idempotencyrecord.FieldResponseBody: + return m.ResponseBody() + case idempotencyrecord.FieldErrorReason: + return m.ErrorReason() + case idempotencyrecord.FieldLockedUntil: + return m.LockedUntil() + case idempotencyrecord.FieldExpiresAt: + return m.ExpiresAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case idempotencyrecord.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case idempotencyrecord.FieldScope: + return m.OldScope(ctx) + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.OldIdempotencyKeyHash(ctx) + case idempotencyrecord.FieldRequestFingerprint: + return m.OldRequestFingerprint(ctx) + case idempotencyrecord.FieldStatus: + return m.OldStatus(ctx) + case idempotencyrecord.FieldResponseStatus: + return m.OldResponseStatus(ctx) + case idempotencyrecord.FieldResponseBody: + return m.OldResponseBody(ctx) + case idempotencyrecord.FieldErrorReason: + return m.OldErrorReason(ctx) + case idempotencyrecord.FieldLockedUntil: + return m.OldLockedUntil(ctx) + case idempotencyrecord.FieldExpiresAt: + return m.OldExpiresAt(ctx) + } + return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case idempotencyrecord.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case idempotencyrecord.FieldScope: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetScope(v) + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdempotencyKeyHash(v) + return nil + case idempotencyrecord.FieldRequestFingerprint: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequestFingerprint(v) + return nil + case idempotencyrecord.FieldStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseStatus(v) + return nil + case idempotencyrecord.FieldResponseBody: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseBody(v) + return nil + case idempotencyrecord.FieldErrorReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorReason(v) + return nil + case idempotencyrecord.FieldLockedUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLockedUntil(v) + return nil + case idempotencyrecord.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdempotencyRecordMutation) AddedFields() []string { + var fields []string + if m.addresponse_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldResponseStatus: + return m.AddedResponseStatus() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseStatus(v) + return nil + } + return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *IdempotencyRecordMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(idempotencyrecord.FieldResponseStatus) { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + if m.FieldCleared(idempotencyrecord.FieldResponseBody) { + fields = append(fields, idempotencyrecord.FieldResponseBody) + } + if m.FieldCleared(idempotencyrecord.FieldErrorReason) { + fields = append(fields, idempotencyrecord.FieldErrorReason) + } + if m.FieldCleared(idempotencyrecord.FieldLockedUntil) { + fields = append(fields, idempotencyrecord.FieldLockedUntil) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *IdempotencyRecordMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearField(name string) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + m.ClearResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ClearResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ClearErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ClearLockedUntil() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *IdempotencyRecordMutation) ResetField(name string) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case idempotencyrecord.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case idempotencyrecord.FieldScope: + m.ResetScope() + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + m.ResetIdempotencyKeyHash() + return nil + case idempotencyrecord.FieldRequestFingerprint: + m.ResetRequestFingerprint() + return nil + case idempotencyrecord.FieldStatus: + m.ResetStatus() + return nil + case idempotencyrecord.FieldResponseStatus: + m.ResetResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ResetResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ResetErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ResetLockedUntil() + return nil + case idempotencyrecord.FieldExpiresAt: + m.ResetExpiresAt() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdempotencyRecordMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *IdempotencyRecordMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *IdempotencyRecordMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord edge %s", name) +} + // PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph. type PromoCodeMutation struct { config @@ -13355,6 +15216,494 @@ func (m *RedeemCodeMutation) ResetEdge(name string) error { return fmt.Errorf("unknown RedeemCode edge %s", name) } +// SecuritySecretMutation represents an operation that mutates the SecuritySecret nodes in the graph. +type SecuritySecretMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + key *string + value *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*SecuritySecret, error) + predicates []predicate.SecuritySecret +} + +var _ ent.Mutation = (*SecuritySecretMutation)(nil) + +// securitysecretOption allows management of the mutation configuration using functional options. +type securitysecretOption func(*SecuritySecretMutation) + +// newSecuritySecretMutation creates new mutation for the SecuritySecret entity. +func newSecuritySecretMutation(c config, op Op, opts ...securitysecretOption) *SecuritySecretMutation { + m := &SecuritySecretMutation{ + config: c, + op: op, + typ: TypeSecuritySecret, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withSecuritySecretID sets the ID field of the mutation. +func withSecuritySecretID(id int64) securitysecretOption { + return func(m *SecuritySecretMutation) { + var ( + err error + once sync.Once + value *SecuritySecret + ) + m.oldValue = func(ctx context.Context) (*SecuritySecret, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().SecuritySecret.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withSecuritySecret sets the old SecuritySecret of the mutation. +func withSecuritySecret(node *SecuritySecret) securitysecretOption { + return func(m *SecuritySecretMutation) { + m.oldValue = func(context.Context) (*SecuritySecret, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m SecuritySecretMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m SecuritySecretMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *SecuritySecretMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *SecuritySecretMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().SecuritySecret.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *SecuritySecretMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *SecuritySecretMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the SecuritySecret entity. +// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SecuritySecretMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *SecuritySecretMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *SecuritySecretMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *SecuritySecretMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the SecuritySecret entity. +// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SecuritySecretMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *SecuritySecretMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetKey sets the "key" field. +func (m *SecuritySecretMutation) SetKey(s string) { + m.key = &s +} + +// Key returns the value of the "key" field in the mutation. +func (m *SecuritySecretMutation) Key() (r string, exists bool) { + v := m.key + if v == nil { + return + } + return *v, true +} + +// OldKey returns the old "key" field's value of the SecuritySecret entity. +// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SecuritySecretMutation) OldKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldKey: %w", err) + } + return oldValue.Key, nil +} + +// ResetKey resets all changes to the "key" field. +func (m *SecuritySecretMutation) ResetKey() { + m.key = nil +} + +// SetValue sets the "value" field. +func (m *SecuritySecretMutation) SetValue(s string) { + m.value = &s +} + +// Value returns the value of the "value" field in the mutation. +func (m *SecuritySecretMutation) Value() (r string, exists bool) { + v := m.value + if v == nil { + return + } + return *v, true +} + +// OldValue returns the old "value" field's value of the SecuritySecret entity. +// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SecuritySecretMutation) OldValue(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldValue is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldValue requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldValue: %w", err) + } + return oldValue.Value, nil +} + +// ResetValue resets all changes to the "value" field. +func (m *SecuritySecretMutation) ResetValue() { + m.value = nil +} + +// Where appends a list predicates to the SecuritySecretMutation builder. +func (m *SecuritySecretMutation) Where(ps ...predicate.SecuritySecret) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the SecuritySecretMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *SecuritySecretMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.SecuritySecret, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *SecuritySecretMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *SecuritySecretMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (SecuritySecret). +func (m *SecuritySecretMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *SecuritySecretMutation) Fields() []string { + fields := make([]string, 0, 4) + if m.created_at != nil { + fields = append(fields, securitysecret.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, securitysecret.FieldUpdatedAt) + } + if m.key != nil { + fields = append(fields, securitysecret.FieldKey) + } + if m.value != nil { + fields = append(fields, securitysecret.FieldValue) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *SecuritySecretMutation) Field(name string) (ent.Value, bool) { + switch name { + case securitysecret.FieldCreatedAt: + return m.CreatedAt() + case securitysecret.FieldUpdatedAt: + return m.UpdatedAt() + case securitysecret.FieldKey: + return m.Key() + case securitysecret.FieldValue: + return m.Value() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *SecuritySecretMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case securitysecret.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case securitysecret.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case securitysecret.FieldKey: + return m.OldKey(ctx) + case securitysecret.FieldValue: + return m.OldValue(ctx) + } + return nil, fmt.Errorf("unknown SecuritySecret field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *SecuritySecretMutation) SetField(name string, value ent.Value) error { + switch name { + case securitysecret.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case securitysecret.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case securitysecret.FieldKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetKey(v) + return nil + case securitysecret.FieldValue: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetValue(v) + return nil + } + return fmt.Errorf("unknown SecuritySecret field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *SecuritySecretMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *SecuritySecretMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *SecuritySecretMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown SecuritySecret numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *SecuritySecretMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *SecuritySecretMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *SecuritySecretMutation) ClearField(name string) error { + return fmt.Errorf("unknown SecuritySecret nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *SecuritySecretMutation) ResetField(name string) error { + switch name { + case securitysecret.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case securitysecret.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case securitysecret.FieldKey: + m.ResetKey() + return nil + case securitysecret.FieldValue: + m.ResetValue() + return nil + } + return fmt.Errorf("unknown SecuritySecret field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *SecuritySecretMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *SecuritySecretMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *SecuritySecretMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *SecuritySecretMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *SecuritySecretMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *SecuritySecretMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *SecuritySecretMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown SecuritySecret unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *SecuritySecretMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown SecuritySecret edge %s", name) +} + // SettingMutation represents an operation that mutates the Setting nodes in the graph. type SettingMutation struct { config @@ -14920,6 +17269,8 @@ type UsageLogMutation struct { image_count *int addimage_count *int image_size *string + media_type *string + cache_ttl_overridden *bool created_at *time.Time clearedFields map[string]struct{} user *int64 @@ -16546,6 +18897,91 @@ func (m *UsageLogMutation) ResetImageSize() { delete(m.clearedFields, usagelog.FieldImageSize) } +// SetMediaType sets the "media_type" field. +func (m *UsageLogMutation) SetMediaType(s string) { + m.media_type = &s +} + +// MediaType returns the value of the "media_type" field in the mutation. +func (m *UsageLogMutation) MediaType() (r string, exists bool) { + v := m.media_type + if v == nil { + return + } + return *v, true +} + +// OldMediaType returns the old "media_type" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMediaType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMediaType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMediaType: %w", err) + } + return oldValue.MediaType, nil +} + +// ClearMediaType clears the value of the "media_type" field. +func (m *UsageLogMutation) ClearMediaType() { + m.media_type = nil + m.clearedFields[usagelog.FieldMediaType] = struct{}{} +} + +// MediaTypeCleared returns if the "media_type" field was cleared in this mutation. +func (m *UsageLogMutation) MediaTypeCleared() bool { + _, ok := m.clearedFields[usagelog.FieldMediaType] + return ok +} + +// ResetMediaType resets all changes to the "media_type" field. +func (m *UsageLogMutation) ResetMediaType() { + m.media_type = nil + delete(m.clearedFields, usagelog.FieldMediaType) +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) { + m.cache_ttl_overridden = &b +} + +// CacheTTLOverridden returns the value of the "cache_ttl_overridden" field in the mutation. +func (m *UsageLogMutation) CacheTTLOverridden() (r bool, exists bool) { + v := m.cache_ttl_overridden + if v == nil { + return + } + return *v, true +} + +// OldCacheTTLOverridden returns the old "cache_ttl_overridden" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheTTLOverridden(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheTTLOverridden is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheTTLOverridden requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheTTLOverridden: %w", err) + } + return oldValue.CacheTTLOverridden, nil +} + +// ResetCacheTTLOverridden resets all changes to the "cache_ttl_overridden" field. +func (m *UsageLogMutation) ResetCacheTTLOverridden() { + m.cache_ttl_overridden = nil +} + // SetCreatedAt sets the "created_at" field. func (m *UsageLogMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -16751,7 +19187,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 30) + fields := make([]string, 0, 32) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -16839,6 +19275,12 @@ func (m *UsageLogMutation) Fields() []string { if m.image_size != nil { fields = append(fields, usagelog.FieldImageSize) } + if m.media_type != nil { + fields = append(fields, usagelog.FieldMediaType) + } + if m.cache_ttl_overridden != nil { + fields = append(fields, usagelog.FieldCacheTTLOverridden) + } if m.created_at != nil { fields = append(fields, usagelog.FieldCreatedAt) } @@ -16908,6 +19350,10 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.ImageCount() case usagelog.FieldImageSize: return m.ImageSize() + case usagelog.FieldMediaType: + return m.MediaType() + case usagelog.FieldCacheTTLOverridden: + return m.CacheTTLOverridden() case usagelog.FieldCreatedAt: return m.CreatedAt() } @@ -16977,6 +19423,10 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldImageCount(ctx) case usagelog.FieldImageSize: return m.OldImageSize(ctx) + case usagelog.FieldMediaType: + return m.OldMediaType(ctx) + case usagelog.FieldCacheTTLOverridden: + return m.OldCacheTTLOverridden(ctx) case usagelog.FieldCreatedAt: return m.OldCreatedAt(ctx) } @@ -17191,6 +19641,20 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetImageSize(v) return nil + case usagelog.FieldMediaType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMediaType(v) + return nil + case usagelog.FieldCacheTTLOverridden: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheTTLOverridden(v) + return nil case usagelog.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -17471,6 +19935,9 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldImageSize) { fields = append(fields, usagelog.FieldImageSize) } + if m.FieldCleared(usagelog.FieldMediaType) { + fields = append(fields, usagelog.FieldMediaType) + } return fields } @@ -17509,6 +19976,9 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldImageSize: m.ClearImageSize() return nil + case usagelog.FieldMediaType: + m.ClearMediaType() + return nil } return fmt.Errorf("unknown UsageLog nullable field %s", name) } @@ -17604,6 +20074,12 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldImageSize: m.ResetImageSize() return nil + case usagelog.FieldMediaType: + m.ResetMediaType() + return nil + case usagelog.FieldCacheTTLOverridden: + m.ResetCacheTTLOverridden() + return nil case usagelog.FieldCreatedAt: m.ResetCreatedAt() return nil @@ -17779,6 +20255,10 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time + sora_storage_quota_bytes *int64 + addsora_storage_quota_bytes *int64 + sora_storage_used_bytes *int64 + addsora_storage_used_bytes *int64 clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -18493,6 +20973,118 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (m *UserMutation) SetSoraStorageQuotaBytes(i int64) { + m.sora_storage_quota_bytes = &i + m.addsora_storage_quota_bytes = nil +} + +// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. +func (m *UserMutation) SoraStorageQuotaBytes() (r int64, exists bool) { + v := m.sora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) + } + return oldValue.SoraStorageQuotaBytes, nil +} + +// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. +func (m *UserMutation) AddSoraStorageQuotaBytes(i int64) { + if m.addsora_storage_quota_bytes != nil { + *m.addsora_storage_quota_bytes += i + } else { + m.addsora_storage_quota_bytes = &i + } +} + +// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. +func (m *UserMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { + v := m.addsora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. +func (m *UserMutation) ResetSoraStorageQuotaBytes() { + m.sora_storage_quota_bytes = nil + m.addsora_storage_quota_bytes = nil +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (m *UserMutation) SetSoraStorageUsedBytes(i int64) { + m.sora_storage_used_bytes = &i + m.addsora_storage_used_bytes = nil +} + +// SoraStorageUsedBytes returns the value of the "sora_storage_used_bytes" field in the mutation. +func (m *UserMutation) SoraStorageUsedBytes() (r int64, exists bool) { + v := m.sora_storage_used_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageUsedBytes returns the old "sora_storage_used_bytes" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSoraStorageUsedBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageUsedBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageUsedBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageUsedBytes: %w", err) + } + return oldValue.SoraStorageUsedBytes, nil +} + +// AddSoraStorageUsedBytes adds i to the "sora_storage_used_bytes" field. +func (m *UserMutation) AddSoraStorageUsedBytes(i int64) { + if m.addsora_storage_used_bytes != nil { + *m.addsora_storage_used_bytes += i + } else { + m.addsora_storage_used_bytes = &i + } +} + +// AddedSoraStorageUsedBytes returns the value that was added to the "sora_storage_used_bytes" field in this mutation. +func (m *UserMutation) AddedSoraStorageUsedBytes() (r int64, exists bool) { + v := m.addsora_storage_used_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageUsedBytes resets all changes to the "sora_storage_used_bytes" field. +func (m *UserMutation) ResetSoraStorageUsedBytes() { + m.sora_storage_used_bytes = nil + m.addsora_storage_used_bytes = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -19013,7 +21605,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 16) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -19056,6 +21648,12 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } + if m.sora_storage_quota_bytes != nil { + fields = append(fields, user.FieldSoraStorageQuotaBytes) + } + if m.sora_storage_used_bytes != nil { + fields = append(fields, user.FieldSoraStorageUsedBytes) + } return fields } @@ -19092,6 +21690,10 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() + case user.FieldSoraStorageQuotaBytes: + return m.SoraStorageQuotaBytes() + case user.FieldSoraStorageUsedBytes: + return m.SoraStorageUsedBytes() } return nil, false } @@ -19129,6 +21731,10 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) + case user.FieldSoraStorageQuotaBytes: + return m.OldSoraStorageQuotaBytes(ctx) + case user.FieldSoraStorageUsedBytes: + return m.OldSoraStorageUsedBytes(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -19236,6 +21842,20 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil + case user.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraStorageQuotaBytes(v) + return nil + case user.FieldSoraStorageUsedBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraStorageUsedBytes(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -19250,6 +21870,12 @@ func (m *UserMutation) AddedFields() []string { if m.addconcurrency != nil { fields = append(fields, user.FieldConcurrency) } + if m.addsora_storage_quota_bytes != nil { + fields = append(fields, user.FieldSoraStorageQuotaBytes) + } + if m.addsora_storage_used_bytes != nil { + fields = append(fields, user.FieldSoraStorageUsedBytes) + } return fields } @@ -19262,6 +21888,10 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) { return m.AddedBalance() case user.FieldConcurrency: return m.AddedConcurrency() + case user.FieldSoraStorageQuotaBytes: + return m.AddedSoraStorageQuotaBytes() + case user.FieldSoraStorageUsedBytes: + return m.AddedSoraStorageUsedBytes() } return nil, false } @@ -19285,6 +21915,20 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { } m.AddConcurrency(v) return nil + case user.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageQuotaBytes(v) + return nil + case user.FieldSoraStorageUsedBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageUsedBytes(v) + return nil } return fmt.Errorf("unknown User numeric field %s", name) } @@ -19375,6 +22019,12 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil + case user.FieldSoraStorageQuotaBytes: + m.ResetSoraStorageQuotaBytes() + return nil + case user.FieldSoraStorageUsedBytes: + m.ResetSoraStorageUsedBytes() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index c12955ef..89d933fc 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -27,6 +27,9 @@ type ErrorPassthroughRule func(*sql.Selector) // Group is the predicate function for group builders. type Group func(*sql.Selector) +// IdempotencyRecord is the predicate function for idempotencyrecord builders. +type IdempotencyRecord func(*sql.Selector) + // PromoCode is the predicate function for promocode builders. type PromoCode func(*sql.Selector) @@ -39,6 +42,9 @@ type Proxy func(*sql.Selector) // RedeemCode is the predicate function for redeemcode builders. type RedeemCode func(*sql.Selector) +// SecuritySecret is the predicate function for securitysecret builders. +type SecuritySecret func(*sql.Selector) + // Setting is the predicate function for setting builders. type Setting func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 4b3c1a4f..65531aae 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -12,11 +12,13 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/schema" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -93,11 +95,11 @@ func init() { // apikey.StatusValidator is a validator for the "status" field. It is called by the builders before save. apikey.StatusValidator = apikeyDescStatus.Validators[0].(func(string) error) // apikeyDescQuota is the schema descriptor for quota field. - apikeyDescQuota := apikeyFields[7].Descriptor() + apikeyDescQuota := apikeyFields[8].Descriptor() // apikey.DefaultQuota holds the default value on creation for the quota field. apikey.DefaultQuota = apikeyDescQuota.Default.(float64) // apikeyDescQuotaUsed is the schema descriptor for quota_used field. - apikeyDescQuotaUsed := apikeyFields[8].Descriptor() + apikeyDescQuotaUsed := apikeyFields[9].Descriptor() // apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field. apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64) accountMixin := schema.Account{}.Mixin() @@ -208,7 +210,7 @@ func init() { // account.DefaultSchedulable holds the default value on creation for the schedulable field. account.DefaultSchedulable = accountDescSchedulable.Default.(bool) // accountDescSessionWindowStatus is the schema descriptor for session_window_status field. - accountDescSessionWindowStatus := accountFields[21].Descriptor() + accountDescSessionWindowStatus := accountFields[23].Descriptor() // account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save. account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error) accountgroupFields := schema.AccountGroup{}.Fields() @@ -326,6 +328,10 @@ func init() { errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor() // errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field. errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool) + // errorpassthroughruleDescSkipMonitoring is the schema descriptor for skip_monitoring field. + errorpassthroughruleDescSkipMonitoring := errorpassthroughruleFields[11].Descriptor() + // errorpassthroughrule.DefaultSkipMonitoring holds the default value on creation for the skip_monitoring field. + errorpassthroughrule.DefaultSkipMonitoring = errorpassthroughruleDescSkipMonitoring.Default.(bool) groupMixin := schema.Group{}.Mixin() groupMixinHooks1 := groupMixin[1].Hooks() group.Hooks[0] = groupMixinHooks1[0] @@ -393,22 +399,65 @@ func init() { groupDescDefaultValidityDays := groupFields[10].Descriptor() // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) + // groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. + groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor() + // group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. + group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64) // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. - groupDescClaudeCodeOnly := groupFields[14].Descriptor() + groupDescClaudeCodeOnly := groupFields[19].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[18].Descriptor() + groupDescModelRoutingEnabled := groupFields[23].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. - groupDescMcpXMLInject := groupFields[19].Descriptor() + groupDescMcpXMLInject := groupFields[24].Descriptor() // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. - groupDescSupportedModelScopes := groupFields[20].Descriptor() + groupDescSupportedModelScopes := groupFields[25].Descriptor() // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) + // groupDescSortOrder is the schema descriptor for sort_order field. + groupDescSortOrder := groupFields[26].Descriptor() + // group.DefaultSortOrder holds the default value on creation for the sort_order field. + group.DefaultSortOrder = groupDescSortOrder.Default.(int) + idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() + idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() + _ = idempotencyrecordMixinFields0 + idempotencyrecordFields := schema.IdempotencyRecord{}.Fields() + _ = idempotencyrecordFields + // idempotencyrecordDescCreatedAt is the schema descriptor for created_at field. + idempotencyrecordDescCreatedAt := idempotencyrecordMixinFields0[0].Descriptor() + // idempotencyrecord.DefaultCreatedAt holds the default value on creation for the created_at field. + idempotencyrecord.DefaultCreatedAt = idempotencyrecordDescCreatedAt.Default.(func() time.Time) + // idempotencyrecordDescUpdatedAt is the schema descriptor for updated_at field. + idempotencyrecordDescUpdatedAt := idempotencyrecordMixinFields0[1].Descriptor() + // idempotencyrecord.DefaultUpdatedAt holds the default value on creation for the updated_at field. + idempotencyrecord.DefaultUpdatedAt = idempotencyrecordDescUpdatedAt.Default.(func() time.Time) + // idempotencyrecord.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + idempotencyrecord.UpdateDefaultUpdatedAt = idempotencyrecordDescUpdatedAt.UpdateDefault.(func() time.Time) + // idempotencyrecordDescScope is the schema descriptor for scope field. + idempotencyrecordDescScope := idempotencyrecordFields[0].Descriptor() + // idempotencyrecord.ScopeValidator is a validator for the "scope" field. It is called by the builders before save. + idempotencyrecord.ScopeValidator = idempotencyrecordDescScope.Validators[0].(func(string) error) + // idempotencyrecordDescIdempotencyKeyHash is the schema descriptor for idempotency_key_hash field. + idempotencyrecordDescIdempotencyKeyHash := idempotencyrecordFields[1].Descriptor() + // idempotencyrecord.IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save. + idempotencyrecord.IdempotencyKeyHashValidator = idempotencyrecordDescIdempotencyKeyHash.Validators[0].(func(string) error) + // idempotencyrecordDescRequestFingerprint is the schema descriptor for request_fingerprint field. + idempotencyrecordDescRequestFingerprint := idempotencyrecordFields[2].Descriptor() + // idempotencyrecord.RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save. + idempotencyrecord.RequestFingerprintValidator = idempotencyrecordDescRequestFingerprint.Validators[0].(func(string) error) + // idempotencyrecordDescStatus is the schema descriptor for status field. + idempotencyrecordDescStatus := idempotencyrecordFields[3].Descriptor() + // idempotencyrecord.StatusValidator is a validator for the "status" field. It is called by the builders before save. + idempotencyrecord.StatusValidator = idempotencyrecordDescStatus.Validators[0].(func(string) error) + // idempotencyrecordDescErrorReason is the schema descriptor for error_reason field. + idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor() + // idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. + idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. @@ -594,6 +643,43 @@ func init() { redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor() // redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field. redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int) + securitysecretMixin := schema.SecuritySecret{}.Mixin() + securitysecretMixinFields0 := securitysecretMixin[0].Fields() + _ = securitysecretMixinFields0 + securitysecretFields := schema.SecuritySecret{}.Fields() + _ = securitysecretFields + // securitysecretDescCreatedAt is the schema descriptor for created_at field. + securitysecretDescCreatedAt := securitysecretMixinFields0[0].Descriptor() + // securitysecret.DefaultCreatedAt holds the default value on creation for the created_at field. + securitysecret.DefaultCreatedAt = securitysecretDescCreatedAt.Default.(func() time.Time) + // securitysecretDescUpdatedAt is the schema descriptor for updated_at field. + securitysecretDescUpdatedAt := securitysecretMixinFields0[1].Descriptor() + // securitysecret.DefaultUpdatedAt holds the default value on creation for the updated_at field. + securitysecret.DefaultUpdatedAt = securitysecretDescUpdatedAt.Default.(func() time.Time) + // securitysecret.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + securitysecret.UpdateDefaultUpdatedAt = securitysecretDescUpdatedAt.UpdateDefault.(func() time.Time) + // securitysecretDescKey is the schema descriptor for key field. + securitysecretDescKey := securitysecretFields[0].Descriptor() + // securitysecret.KeyValidator is a validator for the "key" field. It is called by the builders before save. + securitysecret.KeyValidator = func() func(string) error { + validators := securitysecretDescKey.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(key string) error { + for _, fn := range fns { + if err := fn(key); err != nil { + return err + } + } + return nil + } + }() + // securitysecretDescValue is the schema descriptor for value field. + securitysecretDescValue := securitysecretFields[1].Descriptor() + // securitysecret.ValueValidator is a validator for the "value" field. It is called by the builders before save. + securitysecret.ValueValidator = securitysecretDescValue.Validators[0].(func(string) error) settingFields := schema.Setting{}.Fields() _ = settingFields // settingDescKey is the schema descriptor for key field. @@ -771,8 +857,16 @@ func init() { usagelogDescImageSize := usagelogFields[28].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) + // usagelogDescMediaType is the schema descriptor for media_type field. + usagelogDescMediaType := usagelogFields[29].Descriptor() + // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. + usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) + // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. + usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor() + // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. + usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[29].Descriptor() + usagelogDescCreatedAt := usagelogFields[31].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() @@ -864,6 +958,14 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) + // userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. + userDescSoraStorageQuotaBytes := userFields[11].Descriptor() + // user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. + user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64) + // userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field. + userDescSoraStorageUsedBytes := userFields[12].Descriptor() + // user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field. + user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index 1cfecc2d..443f9e09 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -164,6 +164,19 @@ func (Account) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + // temp_unschedulable_until: 临时不可调度状态解除时间 + // 当命中临时不可调度规则时设置,在此时间前调度器应跳过该账号 + field.Time("temp_unschedulable_until"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + + // temp_unschedulable_reason: 临时不可调度原因,便于排障审计 + field.String("temp_unschedulable_reason"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + // session_window_*: 会话窗口相关字段 // 用于管理某些需要会话时间窗口的 API(如 Claude Pro) field.Time("session_window_start"). @@ -213,6 +226,9 @@ func (Account) Indexes() []ent.Index { index.Fields("rate_limited_at"), // 筛选速率限制账户 index.Fields("rate_limit_reset_at"), // 筛选速率限制解除时间 index.Fields("overload_until"), // 筛选过载账户 - index.Fields("deleted_at"), // 软删除查询优化 + // 调度热路径复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐) + index.Fields("platform", "priority"), + index.Fields("priority", "status"), + index.Fields("deleted_at"), // 软删除查询优化 } } diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index 26d52cb0..c1ac7ac3 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -47,6 +47,10 @@ func (APIKey) Fields() []ent.Field { field.String("status"). MaxLen(20). Default(domain.StatusActive), + field.Time("last_used_at"). + Optional(). + Nillable(). + Comment("Last usage time of this API key"), field.JSON("ip_whitelist", []string{}). Optional(). Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"), @@ -95,6 +99,7 @@ func (APIKey) Indexes() []ent.Index { index.Fields("group_id"), index.Fields("status"), index.Fields("deleted_at"), + index.Fields("last_used_at"), // Index for quota queries index.Fields("quota", "quota_used"), index.Fields("expires_at"), diff --git a/backend/ent/schema/error_passthrough_rule.go b/backend/ent/schema/error_passthrough_rule.go index 4a861f38..63a81230 100644 --- a/backend/ent/schema/error_passthrough_rule.go +++ b/backend/ent/schema/error_passthrough_rule.go @@ -105,6 +105,12 @@ func (ErrorPassthroughRule) Fields() []ent.Field { Optional(). Nillable(), + // skip_monitoring: 是否跳过运维监控记录 + // true: 匹配此规则的错误不会被记录到 ops_error_logs + // false: 正常记录到运维监控(默认行为) + field.Bool("skip_monitoring"). + Default(false), + // description: 规则描述,用于说明规则的用途 field.Text("description"). Optional(). diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 8a3c1a90..3fcf8674 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -87,6 +87,28 @@ func (Group) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + // Sora 按次计费配置(阶段 1) + field.Float("sora_image_price_360"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_image_price_540"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_video_price_per_request"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_video_price_per_request_hd"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + + // Sora 存储配额 + field.Int64("sora_storage_quota_bytes"). + Default(0), + // Claude Code 客户端限制 (added by migration 029) field.Bool("claude_code_only"). Default(false). @@ -121,6 +143,11 @@ func (Group) Fields() []ent.Field { Default([]string{"claude", "gemini_text", "gemini_image"}). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). Comment("支持的模型系列:claude, gemini_text, gemini_image"), + + // 分组排序 (added by migration 052) + field.Int("sort_order"). + Default(0). + Comment("分组显示排序,数值越小越靠前"), } } @@ -149,5 +176,6 @@ func (Group) Indexes() []ent.Index { index.Fields("subscription_type"), index.Fields("is_exclusive"), index.Fields("deleted_at"), + index.Fields("sort_order"), } } diff --git a/backend/ent/schema/idempotency_record.go b/backend/ent/schema/idempotency_record.go new file mode 100644 index 00000000..ed09ad65 --- /dev/null +++ b/backend/ent/schema/idempotency_record.go @@ -0,0 +1,50 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// IdempotencyRecord 幂等请求记录表。 +type IdempotencyRecord struct { + ent.Schema +} + +func (IdempotencyRecord) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "idempotency_records"}, + } +} + +func (IdempotencyRecord) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (IdempotencyRecord) Fields() []ent.Field { + return []ent.Field{ + field.String("scope").MaxLen(128), + field.String("idempotency_key_hash").MaxLen(64), + field.String("request_fingerprint").MaxLen(64), + field.String("status").MaxLen(32), + field.Int("response_status").Optional().Nillable(), + field.String("response_body").Optional().Nillable(), + field.String("error_reason").MaxLen(128).Optional().Nillable(), + field.Time("locked_until").Optional().Nillable(), + field.Time("expires_at"), + } +} + +func (IdempotencyRecord) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("scope", "idempotency_key_hash").Unique(), + index.Fields("expires_at"), + index.Fields("status", "locked_until"), + } +} diff --git a/backend/ent/schema/security_secret.go b/backend/ent/schema/security_secret.go new file mode 100644 index 00000000..ffe6d348 --- /dev/null +++ b/backend/ent/schema/security_secret.go @@ -0,0 +1,42 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" +) + +// SecuritySecret 存储系统级安全密钥(如 JWT 签名密钥、TOTP 加密密钥)。 +type SecuritySecret struct { + ent.Schema +} + +func (SecuritySecret) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "security_secrets"}, + } +} + +func (SecuritySecret) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (SecuritySecret) Fields() []ent.Field { + return []ent.Field{ + field.String("key"). + MaxLen(100). + NotEmpty(). + Unique(), + field.String("value"). + NotEmpty(). + SchemaType(map[string]string{ + dialect.Postgres: "text", + }), + } +} diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index fc7c7165..dcca1a0a 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -118,6 +118,15 @@ func (UsageLog) Fields() []ent.Field { MaxLen(10). Optional(). Nillable(), + // 媒体类型字段(sora 使用) + field.String("media_type"). + MaxLen(16). + Optional(). + Nillable(), + + // Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费) + field.Bool("cache_ttl_overridden"). + Default(false), // 时间戳(只有 created_at,日志不可修改) field.Time("created_at"). @@ -170,5 +179,7 @@ func (UsageLog) Indexes() []ent.Index { // 复合索引用于时间范围查询 index.Fields("user_id", "created_at"), index.Fields("api_key_id", "created_at"), + // 分组维度时间范围查询(线上由 SQL 迁移创建 group_id IS NOT NULL 的部分索引) + index.Fields("group_id", "created_at"), } } diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index d443ef45..0a3b5d9e 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -72,6 +72,12 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). Nillable(), + + // Sora 存储配额 + field.Int64("sora_storage_quota_bytes"). + Default(0), + field.Int64("sora_storage_used_bytes"). + Default(0), } } diff --git a/backend/ent/schema/user_subscription.go b/backend/ent/schema/user_subscription.go index fa13612b..a81850b1 100644 --- a/backend/ent/schema/user_subscription.go +++ b/backend/ent/schema/user_subscription.go @@ -108,6 +108,8 @@ func (UserSubscription) Indexes() []ent.Index { index.Fields("group_id"), index.Fields("status"), index.Fields("expires_at"), + // 活跃订阅查询复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐) + index.Fields("user_id", "status", "expires_at"), index.Fields("assigned_by"), // 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重新订阅 // 见迁移文件 016_soft_delete_partial_unique_indexes.sql diff --git a/backend/ent/securitysecret.go b/backend/ent/securitysecret.go new file mode 100644 index 00000000..e0e93c91 --- /dev/null +++ b/backend/ent/securitysecret.go @@ -0,0 +1,139 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecret is the model entity for the SecuritySecret schema. +type SecuritySecret struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Key holds the value of the "key" field. + Key string `json:"key,omitempty"` + // Value holds the value of the "value" field. + Value string `json:"value,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*SecuritySecret) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case securitysecret.FieldID: + values[i] = new(sql.NullInt64) + case securitysecret.FieldKey, securitysecret.FieldValue: + values[i] = new(sql.NullString) + case securitysecret.FieldCreatedAt, securitysecret.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the SecuritySecret fields. +func (_m *SecuritySecret) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case securitysecret.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case securitysecret.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case securitysecret.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case securitysecret.FieldKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field key", values[i]) + } else if value.Valid { + _m.Key = value.String + } + case securitysecret.FieldValue: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field value", values[i]) + } else if value.Valid { + _m.Value = value.String + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// GetValue returns the ent.Value that was dynamically selected and assigned to the SecuritySecret. +// This includes values selected through modifiers, order, etc. +func (_m *SecuritySecret) GetValue(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this SecuritySecret. +// Note that you need to call SecuritySecret.Unwrap() before calling this method if this SecuritySecret +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *SecuritySecret) Update() *SecuritySecretUpdateOne { + return NewSecuritySecretClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the SecuritySecret entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *SecuritySecret) Unwrap() *SecuritySecret { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: SecuritySecret is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *SecuritySecret) String() string { + var builder strings.Builder + builder.WriteString("SecuritySecret(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("key=") + builder.WriteString(_m.Key) + builder.WriteString(", ") + builder.WriteString("value=") + builder.WriteString(_m.Value) + builder.WriteByte(')') + return builder.String() +} + +// SecuritySecrets is a parsable slice of SecuritySecret. +type SecuritySecrets []*SecuritySecret diff --git a/backend/ent/securitysecret/securitysecret.go b/backend/ent/securitysecret/securitysecret.go new file mode 100644 index 00000000..4c5d9ef6 --- /dev/null +++ b/backend/ent/securitysecret/securitysecret.go @@ -0,0 +1,86 @@ +// Code generated by ent, DO NOT EDIT. + +package securitysecret + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the securitysecret type in the database. + Label = "security_secret" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldKey holds the string denoting the key field in the database. + FieldKey = "key" + // FieldValue holds the string denoting the value field in the database. + FieldValue = "value" + // Table holds the table name of the securitysecret in the database. + Table = "security_secrets" +) + +// Columns holds all SQL columns for securitysecret fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldKey, + FieldValue, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // KeyValidator is a validator for the "key" field. It is called by the builders before save. + KeyValidator func(string) error + // ValueValidator is a validator for the "value" field. It is called by the builders before save. + ValueValidator func(string) error +) + +// OrderOption defines the ordering options for the SecuritySecret queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByKey orders the results by the key field. +func ByKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldKey, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} diff --git a/backend/ent/securitysecret/where.go b/backend/ent/securitysecret/where.go new file mode 100644 index 00000000..34f50752 --- /dev/null +++ b/backend/ent/securitysecret/where.go @@ -0,0 +1,300 @@ +// Code generated by ent, DO NOT EDIT. + +package securitysecret + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Key applies equality check predicate on the "key" field. It's identical to KeyEQ. +func Key(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldKey, v)) +} + +// Value applies equality check predicate on the "value" field. It's identical to ValueEQ. +func Value(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldValue, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// KeyEQ applies the EQ predicate on the "key" field. +func KeyEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldKey, v)) +} + +// KeyNEQ applies the NEQ predicate on the "key" field. +func KeyNEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldKey, v)) +} + +// KeyIn applies the In predicate on the "key" field. +func KeyIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldKey, vs...)) +} + +// KeyNotIn applies the NotIn predicate on the "key" field. +func KeyNotIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldKey, vs...)) +} + +// KeyGT applies the GT predicate on the "key" field. +func KeyGT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldKey, v)) +} + +// KeyGTE applies the GTE predicate on the "key" field. +func KeyGTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldKey, v)) +} + +// KeyLT applies the LT predicate on the "key" field. +func KeyLT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldKey, v)) +} + +// KeyLTE applies the LTE predicate on the "key" field. +func KeyLTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldKey, v)) +} + +// KeyContains applies the Contains predicate on the "key" field. +func KeyContains(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContains(FieldKey, v)) +} + +// KeyHasPrefix applies the HasPrefix predicate on the "key" field. +func KeyHasPrefix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasPrefix(FieldKey, v)) +} + +// KeyHasSuffix applies the HasSuffix predicate on the "key" field. +func KeyHasSuffix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasSuffix(FieldKey, v)) +} + +// KeyEqualFold applies the EqualFold predicate on the "key" field. +func KeyEqualFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEqualFold(FieldKey, v)) +} + +// KeyContainsFold applies the ContainsFold predicate on the "key" field. +func KeyContainsFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContainsFold(FieldKey, v)) +} + +// ValueEQ applies the EQ predicate on the "value" field. +func ValueEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldValue, v)) +} + +// ValueNEQ applies the NEQ predicate on the "value" field. +func ValueNEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldValue, v)) +} + +// ValueIn applies the In predicate on the "value" field. +func ValueIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldValue, vs...)) +} + +// ValueNotIn applies the NotIn predicate on the "value" field. +func ValueNotIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldValue, vs...)) +} + +// ValueGT applies the GT predicate on the "value" field. +func ValueGT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldValue, v)) +} + +// ValueGTE applies the GTE predicate on the "value" field. +func ValueGTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldValue, v)) +} + +// ValueLT applies the LT predicate on the "value" field. +func ValueLT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldValue, v)) +} + +// ValueLTE applies the LTE predicate on the "value" field. +func ValueLTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldValue, v)) +} + +// ValueContains applies the Contains predicate on the "value" field. +func ValueContains(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContains(FieldValue, v)) +} + +// ValueHasPrefix applies the HasPrefix predicate on the "value" field. +func ValueHasPrefix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasPrefix(FieldValue, v)) +} + +// ValueHasSuffix applies the HasSuffix predicate on the "value" field. +func ValueHasSuffix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasSuffix(FieldValue, v)) +} + +// ValueEqualFold applies the EqualFold predicate on the "value" field. +func ValueEqualFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEqualFold(FieldValue, v)) +} + +// ValueContainsFold applies the ContainsFold predicate on the "value" field. +func ValueContainsFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContainsFold(FieldValue, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.SecuritySecret) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.SecuritySecret) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.SecuritySecret) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.NotPredicates(p)) +} diff --git a/backend/ent/securitysecret_create.go b/backend/ent/securitysecret_create.go new file mode 100644 index 00000000..397503be --- /dev/null +++ b/backend/ent/securitysecret_create.go @@ -0,0 +1,626 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretCreate is the builder for creating a SecuritySecret entity. +type SecuritySecretCreate struct { + config + mutation *SecuritySecretMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *SecuritySecretCreate) SetCreatedAt(v time.Time) *SecuritySecretCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *SecuritySecretCreate) SetNillableCreatedAt(v *time.Time) *SecuritySecretCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *SecuritySecretCreate) SetUpdatedAt(v time.Time) *SecuritySecretCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *SecuritySecretCreate) SetNillableUpdatedAt(v *time.Time) *SecuritySecretCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetKey sets the "key" field. +func (_c *SecuritySecretCreate) SetKey(v string) *SecuritySecretCreate { + _c.mutation.SetKey(v) + return _c +} + +// SetValue sets the "value" field. +func (_c *SecuritySecretCreate) SetValue(v string) *SecuritySecretCreate { + _c.mutation.SetValue(v) + return _c +} + +// Mutation returns the SecuritySecretMutation object of the builder. +func (_c *SecuritySecretCreate) Mutation() *SecuritySecretMutation { + return _c.mutation +} + +// Save creates the SecuritySecret in the database. +func (_c *SecuritySecretCreate) Save(ctx context.Context) (*SecuritySecret, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *SecuritySecretCreate) SaveX(ctx context.Context) *SecuritySecret { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *SecuritySecretCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SecuritySecretCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *SecuritySecretCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := securitysecret.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := securitysecret.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *SecuritySecretCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "SecuritySecret.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "SecuritySecret.updated_at"`)} + } + if _, ok := _c.mutation.Key(); !ok { + return &ValidationError{Name: "key", err: errors.New(`ent: missing required field "SecuritySecret.key"`)} + } + if v, ok := _c.mutation.Key(); ok { + if err := securitysecret.KeyValidator(v); err != nil { + return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)} + } + } + if _, ok := _c.mutation.Value(); !ok { + return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "SecuritySecret.value"`)} + } + if v, ok := _c.mutation.Value(); ok { + if err := securitysecret.ValueValidator(v); err != nil { + return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)} + } + } + return nil +} + +func (_c *SecuritySecretCreate) sqlSave(ctx context.Context) (*SecuritySecret, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *SecuritySecretCreate) createSpec() (*SecuritySecret, *sqlgraph.CreateSpec) { + var ( + _node = &SecuritySecret{config: _c.config} + _spec = sqlgraph.NewCreateSpec(securitysecret.Table, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(securitysecret.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Key(); ok { + _spec.SetField(securitysecret.FieldKey, field.TypeString, value) + _node.Key = value + } + if value, ok := _c.mutation.Value(); ok { + _spec.SetField(securitysecret.FieldValue, field.TypeString, value) + _node.Value = value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SecuritySecret.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SecuritySecretUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *SecuritySecretCreate) OnConflict(opts ...sql.ConflictOption) *SecuritySecretUpsertOne { + _c.conflict = opts + return &SecuritySecretUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SecuritySecretCreate) OnConflictColumns(columns ...string) *SecuritySecretUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SecuritySecretUpsertOne{ + create: _c, + } +} + +type ( + // SecuritySecretUpsertOne is the builder for "upsert"-ing + // one SecuritySecret node. + SecuritySecretUpsertOne struct { + create *SecuritySecretCreate + } + + // SecuritySecretUpsert is the "OnConflict" setter. + SecuritySecretUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *SecuritySecretUpsert) SetUpdatedAt(v time.Time) *SecuritySecretUpsert { + u.Set(securitysecret.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SecuritySecretUpsert) UpdateUpdatedAt() *SecuritySecretUpsert { + u.SetExcluded(securitysecret.FieldUpdatedAt) + return u +} + +// SetKey sets the "key" field. +func (u *SecuritySecretUpsert) SetKey(v string) *SecuritySecretUpsert { + u.Set(securitysecret.FieldKey, v) + return u +} + +// UpdateKey sets the "key" field to the value that was provided on create. +func (u *SecuritySecretUpsert) UpdateKey() *SecuritySecretUpsert { + u.SetExcluded(securitysecret.FieldKey) + return u +} + +// SetValue sets the "value" field. +func (u *SecuritySecretUpsert) SetValue(v string) *SecuritySecretUpsert { + u.Set(securitysecret.FieldValue, v) + return u +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SecuritySecretUpsert) UpdateValue() *SecuritySecretUpsert { + u.SetExcluded(securitysecret.FieldValue) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *SecuritySecretUpsertOne) UpdateNewValues() *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(securitysecret.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SecuritySecretUpsertOne) Ignore() *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SecuritySecretUpsertOne) DoNothing() *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SecuritySecretCreate.OnConflict +// documentation for more info. +func (u *SecuritySecretUpsertOne) Update(set func(*SecuritySecretUpsert)) *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SecuritySecretUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SecuritySecretUpsertOne) SetUpdatedAt(v time.Time) *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SecuritySecretUpsertOne) UpdateUpdatedAt() *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetKey sets the "key" field. +func (u *SecuritySecretUpsertOne) SetKey(v string) *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetKey(v) + }) +} + +// UpdateKey sets the "key" field to the value that was provided on create. +func (u *SecuritySecretUpsertOne) UpdateKey() *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateKey() + }) +} + +// SetValue sets the "value" field. +func (u *SecuritySecretUpsertOne) SetValue(v string) *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetValue(v) + }) +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SecuritySecretUpsertOne) UpdateValue() *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateValue() + }) +} + +// Exec executes the query. +func (u *SecuritySecretUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SecuritySecretCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SecuritySecretUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *SecuritySecretUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *SecuritySecretUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// SecuritySecretCreateBulk is the builder for creating many SecuritySecret entities in bulk. +type SecuritySecretCreateBulk struct { + config + err error + builders []*SecuritySecretCreate + conflict []sql.ConflictOption +} + +// Save creates the SecuritySecret entities in the database. +func (_c *SecuritySecretCreateBulk) Save(ctx context.Context) ([]*SecuritySecret, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*SecuritySecret, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*SecuritySecretMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *SecuritySecretCreateBulk) SaveX(ctx context.Context) []*SecuritySecret { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *SecuritySecretCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SecuritySecretCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SecuritySecret.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SecuritySecretUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *SecuritySecretCreateBulk) OnConflict(opts ...sql.ConflictOption) *SecuritySecretUpsertBulk { + _c.conflict = opts + return &SecuritySecretUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SecuritySecretCreateBulk) OnConflictColumns(columns ...string) *SecuritySecretUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SecuritySecretUpsertBulk{ + create: _c, + } +} + +// SecuritySecretUpsertBulk is the builder for "upsert"-ing +// a bulk of SecuritySecret nodes. +type SecuritySecretUpsertBulk struct { + create *SecuritySecretCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *SecuritySecretUpsertBulk) UpdateNewValues() *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(securitysecret.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SecuritySecretUpsertBulk) Ignore() *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SecuritySecretUpsertBulk) DoNothing() *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SecuritySecretCreateBulk.OnConflict +// documentation for more info. +func (u *SecuritySecretUpsertBulk) Update(set func(*SecuritySecretUpsert)) *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SecuritySecretUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SecuritySecretUpsertBulk) SetUpdatedAt(v time.Time) *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SecuritySecretUpsertBulk) UpdateUpdatedAt() *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetKey sets the "key" field. +func (u *SecuritySecretUpsertBulk) SetKey(v string) *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetKey(v) + }) +} + +// UpdateKey sets the "key" field to the value that was provided on create. +func (u *SecuritySecretUpsertBulk) UpdateKey() *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateKey() + }) +} + +// SetValue sets the "value" field. +func (u *SecuritySecretUpsertBulk) SetValue(v string) *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetValue(v) + }) +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SecuritySecretUpsertBulk) UpdateValue() *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateValue() + }) +} + +// Exec executes the query. +func (u *SecuritySecretUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the SecuritySecretCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SecuritySecretCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SecuritySecretUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/securitysecret_delete.go b/backend/ent/securitysecret_delete.go new file mode 100644 index 00000000..66757138 --- /dev/null +++ b/backend/ent/securitysecret_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretDelete is the builder for deleting a SecuritySecret entity. +type SecuritySecretDelete struct { + config + hooks []Hook + mutation *SecuritySecretMutation +} + +// Where appends a list predicates to the SecuritySecretDelete builder. +func (_d *SecuritySecretDelete) Where(ps ...predicate.SecuritySecret) *SecuritySecretDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *SecuritySecretDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SecuritySecretDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *SecuritySecretDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(securitysecret.Table, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// SecuritySecretDeleteOne is the builder for deleting a single SecuritySecret entity. +type SecuritySecretDeleteOne struct { + _d *SecuritySecretDelete +} + +// Where appends a list predicates to the SecuritySecretDelete builder. +func (_d *SecuritySecretDeleteOne) Where(ps ...predicate.SecuritySecret) *SecuritySecretDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *SecuritySecretDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{securitysecret.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SecuritySecretDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/securitysecret_query.go b/backend/ent/securitysecret_query.go new file mode 100644 index 00000000..fe53adf1 --- /dev/null +++ b/backend/ent/securitysecret_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretQuery is the builder for querying SecuritySecret entities. +type SecuritySecretQuery struct { + config + ctx *QueryContext + order []securitysecret.OrderOption + inters []Interceptor + predicates []predicate.SecuritySecret + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the SecuritySecretQuery builder. +func (_q *SecuritySecretQuery) Where(ps ...predicate.SecuritySecret) *SecuritySecretQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *SecuritySecretQuery) Limit(limit int) *SecuritySecretQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *SecuritySecretQuery) Offset(offset int) *SecuritySecretQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *SecuritySecretQuery) Unique(unique bool) *SecuritySecretQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *SecuritySecretQuery) Order(o ...securitysecret.OrderOption) *SecuritySecretQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first SecuritySecret entity from the query. +// Returns a *NotFoundError when no SecuritySecret was found. +func (_q *SecuritySecretQuery) First(ctx context.Context) (*SecuritySecret, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{securitysecret.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *SecuritySecretQuery) FirstX(ctx context.Context) *SecuritySecret { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first SecuritySecret ID from the query. +// Returns a *NotFoundError when no SecuritySecret ID was found. +func (_q *SecuritySecretQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{securitysecret.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *SecuritySecretQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single SecuritySecret entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one SecuritySecret entity is found. +// Returns a *NotFoundError when no SecuritySecret entities are found. +func (_q *SecuritySecretQuery) Only(ctx context.Context) (*SecuritySecret, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{securitysecret.Label} + default: + return nil, &NotSingularError{securitysecret.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *SecuritySecretQuery) OnlyX(ctx context.Context) *SecuritySecret { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only SecuritySecret ID in the query. +// Returns a *NotSingularError when more than one SecuritySecret ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *SecuritySecretQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{securitysecret.Label} + default: + err = &NotSingularError{securitysecret.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *SecuritySecretQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of SecuritySecrets. +func (_q *SecuritySecretQuery) All(ctx context.Context) ([]*SecuritySecret, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*SecuritySecret, *SecuritySecretQuery]() + return withInterceptors[[]*SecuritySecret](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *SecuritySecretQuery) AllX(ctx context.Context) []*SecuritySecret { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of SecuritySecret IDs. +func (_q *SecuritySecretQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(securitysecret.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *SecuritySecretQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *SecuritySecretQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*SecuritySecretQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *SecuritySecretQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *SecuritySecretQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *SecuritySecretQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the SecuritySecretQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *SecuritySecretQuery) Clone() *SecuritySecretQuery { + if _q == nil { + return nil + } + return &SecuritySecretQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]securitysecret.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.SecuritySecret{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.SecuritySecret.Query(). +// GroupBy(securitysecret.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *SecuritySecretQuery) GroupBy(field string, fields ...string) *SecuritySecretGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &SecuritySecretGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = securitysecret.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.SecuritySecret.Query(). +// Select(securitysecret.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *SecuritySecretQuery) Select(fields ...string) *SecuritySecretSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &SecuritySecretSelect{SecuritySecretQuery: _q} + sbuild.label = securitysecret.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a SecuritySecretSelect configured with the given aggregations. +func (_q *SecuritySecretQuery) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *SecuritySecretQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !securitysecret.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *SecuritySecretQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SecuritySecret, error) { + var ( + nodes = []*SecuritySecret{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*SecuritySecret).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &SecuritySecret{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *SecuritySecretQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *SecuritySecretQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID) + for i := range fields { + if fields[i] != securitysecret.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *SecuritySecretQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(securitysecret.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = securitysecret.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *SecuritySecretQuery) ForUpdate(opts ...sql.LockOption) *SecuritySecretQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *SecuritySecretQuery) ForShare(opts ...sql.LockOption) *SecuritySecretQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// SecuritySecretGroupBy is the group-by builder for SecuritySecret entities. +type SecuritySecretGroupBy struct { + selector + build *SecuritySecretQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *SecuritySecretGroupBy) Aggregate(fns ...AggregateFunc) *SecuritySecretGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *SecuritySecretGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *SecuritySecretGroupBy) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// SecuritySecretSelect is the builder for selecting fields of SecuritySecret entities. +type SecuritySecretSelect struct { + *SecuritySecretQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *SecuritySecretSelect) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *SecuritySecretSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretSelect](ctx, _s.SecuritySecretQuery, _s, _s.inters, v) +} + +func (_s *SecuritySecretSelect) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/securitysecret_update.go b/backend/ent/securitysecret_update.go new file mode 100644 index 00000000..ec3979af --- /dev/null +++ b/backend/ent/securitysecret_update.go @@ -0,0 +1,316 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretUpdate is the builder for updating SecuritySecret entities. +type SecuritySecretUpdate struct { + config + hooks []Hook + mutation *SecuritySecretMutation +} + +// Where appends a list predicates to the SecuritySecretUpdate builder. +func (_u *SecuritySecretUpdate) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *SecuritySecretUpdate) SetUpdatedAt(v time.Time) *SecuritySecretUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetKey sets the "key" field. +func (_u *SecuritySecretUpdate) SetKey(v string) *SecuritySecretUpdate { + _u.mutation.SetKey(v) + return _u +} + +// SetNillableKey sets the "key" field if the given value is not nil. +func (_u *SecuritySecretUpdate) SetNillableKey(v *string) *SecuritySecretUpdate { + if v != nil { + _u.SetKey(*v) + } + return _u +} + +// SetValue sets the "value" field. +func (_u *SecuritySecretUpdate) SetValue(v string) *SecuritySecretUpdate { + _u.mutation.SetValue(v) + return _u +} + +// SetNillableValue sets the "value" field if the given value is not nil. +func (_u *SecuritySecretUpdate) SetNillableValue(v *string) *SecuritySecretUpdate { + if v != nil { + _u.SetValue(*v) + } + return _u +} + +// Mutation returns the SecuritySecretMutation object of the builder. +func (_u *SecuritySecretUpdate) Mutation() *SecuritySecretMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *SecuritySecretUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SecuritySecretUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *SecuritySecretUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SecuritySecretUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *SecuritySecretUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := securitysecret.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *SecuritySecretUpdate) check() error { + if v, ok := _u.mutation.Key(); ok { + if err := securitysecret.KeyValidator(v); err != nil { + return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)} + } + } + if v, ok := _u.mutation.Value(); ok { + if err := securitysecret.ValueValidator(v); err != nil { + return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)} + } + } + return nil +} + +func (_u *SecuritySecretUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Key(); ok { + _spec.SetField(securitysecret.FieldKey, field.TypeString, value) + } + if value, ok := _u.mutation.Value(); ok { + _spec.SetField(securitysecret.FieldValue, field.TypeString, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{securitysecret.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// SecuritySecretUpdateOne is the builder for updating a single SecuritySecret entity. +type SecuritySecretUpdateOne struct { + config + fields []string + hooks []Hook + mutation *SecuritySecretMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *SecuritySecretUpdateOne) SetUpdatedAt(v time.Time) *SecuritySecretUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetKey sets the "key" field. +func (_u *SecuritySecretUpdateOne) SetKey(v string) *SecuritySecretUpdateOne { + _u.mutation.SetKey(v) + return _u +} + +// SetNillableKey sets the "key" field if the given value is not nil. +func (_u *SecuritySecretUpdateOne) SetNillableKey(v *string) *SecuritySecretUpdateOne { + if v != nil { + _u.SetKey(*v) + } + return _u +} + +// SetValue sets the "value" field. +func (_u *SecuritySecretUpdateOne) SetValue(v string) *SecuritySecretUpdateOne { + _u.mutation.SetValue(v) + return _u +} + +// SetNillableValue sets the "value" field if the given value is not nil. +func (_u *SecuritySecretUpdateOne) SetNillableValue(v *string) *SecuritySecretUpdateOne { + if v != nil { + _u.SetValue(*v) + } + return _u +} + +// Mutation returns the SecuritySecretMutation object of the builder. +func (_u *SecuritySecretUpdateOne) Mutation() *SecuritySecretMutation { + return _u.mutation +} + +// Where appends a list predicates to the SecuritySecretUpdate builder. +func (_u *SecuritySecretUpdateOne) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *SecuritySecretUpdateOne) Select(field string, fields ...string) *SecuritySecretUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated SecuritySecret entity. +func (_u *SecuritySecretUpdateOne) Save(ctx context.Context) (*SecuritySecret, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SecuritySecretUpdateOne) SaveX(ctx context.Context) *SecuritySecret { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *SecuritySecretUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SecuritySecretUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *SecuritySecretUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := securitysecret.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *SecuritySecretUpdateOne) check() error { + if v, ok := _u.mutation.Key(); ok { + if err := securitysecret.KeyValidator(v); err != nil { + return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)} + } + } + if v, ok := _u.mutation.Value(); ok { + if err := securitysecret.ValueValidator(v); err != nil { + return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)} + } + } + return nil +} + +func (_u *SecuritySecretUpdateOne) sqlSave(ctx context.Context) (_node *SecuritySecret, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SecuritySecret.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID) + for _, f := range fields { + if !securitysecret.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != securitysecret.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Key(); ok { + _spec.SetField(securitysecret.FieldKey, field.TypeString, value) + } + if value, ok := _u.mutation.Value(); ok { + _spec.SetField(securitysecret.FieldValue, field.TypeString, value) + } + _node = &SecuritySecret{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{securitysecret.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 45d83428..cd3b2296 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -28,6 +28,8 @@ type Tx struct { ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient + // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. + IdempotencyRecord *IdempotencyRecordClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -36,6 +38,8 @@ type Tx struct { Proxy *ProxyClient // RedeemCode is the client for interacting with the RedeemCode builders. RedeemCode *RedeemCodeClient + // SecuritySecret is the client for interacting with the SecuritySecret builders. + SecuritySecret *SecuritySecretClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. @@ -190,10 +194,12 @@ func (tx *Tx) init() { tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) + tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.Proxy = NewProxyClient(tx.config) tx.RedeemCode = NewRedeemCodeClient(tx.config) + tx.SecuritySecret = NewSecuritySecretClient(tx.config) tx.Setting = NewSettingClient(tx.config) tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config) tx.UsageLog = NewUsageLogClient(tx.config) diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index 81c466b4..f6968d0d 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -80,6 +80,10 @@ type UsageLog struct { ImageCount int `json:"image_count,omitempty"` // ImageSize holds the value of the "image_size" field. ImageSize *string `json:"image_size,omitempty"` + // MediaType holds the value of the "media_type" field. + MediaType *string `json:"media_type,omitempty"` + // CacheTTLOverridden holds the value of the "cache_ttl_overridden" field. + CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. @@ -165,13 +169,13 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case usagelog.FieldStream: + case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden: values[i] = new(sql.NullBool) case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier: values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -378,6 +382,19 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.ImageSize = new(string) *_m.ImageSize = value.String } + case usagelog.FieldMediaType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field media_type", values[i]) + } else if value.Valid { + _m.MediaType = new(string) + *_m.MediaType = value.String + } + case usagelog.FieldCacheTTLOverridden: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i]) + } else if value.Valid { + _m.CacheTTLOverridden = value.Bool + } case usagelog.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -548,6 +565,14 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.MediaType; v != nil { + builder.WriteString("media_type=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("cache_ttl_overridden=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteByte(')') diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index 980f1e58..ba97b843 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -72,6 +72,10 @@ const ( FieldImageCount = "image_count" // FieldImageSize holds the string denoting the image_size field in the database. FieldImageSize = "image_size" + // FieldMediaType holds the string denoting the media_type field in the database. + FieldMediaType = "media_type" + // FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database. + FieldCacheTTLOverridden = "cache_ttl_overridden" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // EdgeUser holds the string denoting the user edge name in mutations. @@ -155,6 +159,8 @@ var Columns = []string{ FieldIPAddress, FieldImageCount, FieldImageSize, + FieldMediaType, + FieldCacheTTLOverridden, FieldCreatedAt, } @@ -211,6 +217,10 @@ var ( DefaultImageCount int // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. ImageSizeValidator func(string) error + // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. + MediaTypeValidator func(string) error + // DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field. + DefaultCacheTTLOverridden bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time ) @@ -368,6 +378,16 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImageSize, opts...).ToFunc() } +// ByMediaType orders the results by the media_type field. +func ByMediaType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMediaType, opts...).ToFunc() +} + +// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field. +func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index 28e2ab4c..af960335 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -200,6 +200,16 @@ func ImageSize(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v)) } +// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ. +func MediaType(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) +} + +// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ. +func CacheTTLOverridden(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) @@ -1440,6 +1450,91 @@ func ImageSizeContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v)) } +// MediaTypeEQ applies the EQ predicate on the "media_type" field. +func MediaTypeEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) +} + +// MediaTypeNEQ applies the NEQ predicate on the "media_type" field. +func MediaTypeNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v)) +} + +// MediaTypeIn applies the In predicate on the "media_type" field. +func MediaTypeIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...)) +} + +// MediaTypeNotIn applies the NotIn predicate on the "media_type" field. +func MediaTypeNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...)) +} + +// MediaTypeGT applies the GT predicate on the "media_type" field. +func MediaTypeGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldMediaType, v)) +} + +// MediaTypeGTE applies the GTE predicate on the "media_type" field. +func MediaTypeGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v)) +} + +// MediaTypeLT applies the LT predicate on the "media_type" field. +func MediaTypeLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldMediaType, v)) +} + +// MediaTypeLTE applies the LTE predicate on the "media_type" field. +func MediaTypeLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v)) +} + +// MediaTypeContains applies the Contains predicate on the "media_type" field. +func MediaTypeContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldMediaType, v)) +} + +// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field. +func MediaTypeHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v)) +} + +// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field. +func MediaTypeHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v)) +} + +// MediaTypeIsNil applies the IsNil predicate on the "media_type" field. +func MediaTypeIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldMediaType)) +} + +// MediaTypeNotNil applies the NotNil predicate on the "media_type" field. +func MediaTypeNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldMediaType)) +} + +// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field. +func MediaTypeEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v)) +} + +// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field. +func MediaTypeContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v)) +} + +// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field. +func CacheTTLOverriddenEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) +} + +// CacheTTLOverriddenNEQ applies the NEQ predicate on the "cache_ttl_overridden" field. +func CacheTTLOverriddenNEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheTTLOverridden, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index a17d6507..e0285a5e 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -393,6 +393,34 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate { return _c } +// SetMediaType sets the "media_type" field. +func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate { + _c.mutation.SetMediaType(v) + return _c +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate { + if v != nil { + _c.SetMediaType(*v) + } + return _c +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate { + _c.mutation.SetCacheTTLOverridden(v) + return _c +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheTTLOverridden(v *bool) *UsageLogCreate { + if v != nil { + _c.SetCacheTTLOverridden(*v) + } + return _c +} + // SetCreatedAt sets the "created_at" field. func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate { _c.mutation.SetCreatedAt(v) @@ -531,6 +559,10 @@ func (_c *UsageLogCreate) defaults() { v := usagelog.DefaultImageCount _c.mutation.SetImageCount(v) } + if _, ok := _c.mutation.CacheTTLOverridden(); !ok { + v := usagelog.DefaultCacheTTLOverridden + _c.mutation.SetCacheTTLOverridden(v) + } if _, ok := _c.mutation.CreatedAt(); !ok { v := usagelog.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) @@ -627,6 +659,14 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _c.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } + if _, ok := _c.mutation.CacheTTLOverridden(); !ok { + return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)} + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)} } @@ -762,6 +802,14 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldImageSize, field.TypeString, value) _node.ImageSize = &value } + if value, ok := _c.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + _node.MediaType = &value + } + if value, ok := _c.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + _node.CacheTTLOverridden = value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -1407,6 +1455,36 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert { return u } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert { + u.Set(usagelog.FieldMediaType, v) + return u +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldMediaType) + return u +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert { + u.SetNull(usagelog.FieldMediaType) + return u +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert { + u.Set(usagelog.FieldCacheTTLOverridden, v) + return u +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheTTLOverridden() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheTTLOverridden) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -2040,6 +2118,41 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne { }) } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateMediaType() + }) +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearMediaType() + }) +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheTTLOverridden(v) + }) +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheTTLOverridden() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheTTLOverridden() + }) +} + // Exec executes the query. func (u *UsageLogUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2839,6 +2952,41 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk { }) } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateMediaType() + }) +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearMediaType() + }) +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheTTLOverridden(v) + }) +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheTTLOverridden() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheTTLOverridden() + }) +} + // Exec executes the query. func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 571a7b3c..b46e5b56 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -612,6 +612,40 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate { return _u } +// SetMediaType sets the "media_type" field. +func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// ClearMediaType clears the value of the "media_type" field. +func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate { + _u.mutation.ClearMediaType() + return _u +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate { + _u.mutation.SetCacheTTLOverridden(v) + return _u +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdate { + if v != nil { + _u.SetCacheTTLOverridden(*v) + } + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate { return _u.SetUserID(v.ID) @@ -726,6 +760,11 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -894,6 +933,15 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + } + if _u.mutation.MediaTypeCleared() { + _spec.ClearField(usagelog.FieldMediaType, field.TypeString) + } + if value, ok := _u.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -1639,6 +1687,40 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne { return _u } +// SetMediaType sets the "media_type" field. +func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// ClearMediaType clears the value of the "media_type" field. +func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne { + _u.mutation.ClearMediaType() + return _u +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne { + _u.mutation.SetCacheTTLOverridden(v) + return _u +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheTTLOverridden(*v) + } + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne { return _u.SetUserID(v.ID) @@ -1766,6 +1848,11 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -1951,6 +2038,15 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + } + if _u.mutation.MediaTypeCleared() { + _spec.ClearField(usagelog.FieldMediaType, field.TypeString) + } + if value, ok := _u.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/ent/user.go b/backend/ent/user.go index 2435aa1b..b3f933f6 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,6 +45,10 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` + // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` + // SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field. + SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -177,7 +181,7 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case user.FieldBalance: values[i] = new(sql.NullFloat64) - case user.FieldID, user.FieldConcurrency: + case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes: values[i] = new(sql.NullInt64) case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted: values[i] = new(sql.NullString) @@ -291,6 +295,18 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } + case user.FieldSoraStorageQuotaBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageQuotaBytes = value.Int64 + } + case user.FieldSoraStorageUsedBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageUsedBytes = value.Int64 + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -424,6 +440,12 @@ func (_m *User) String() string { builder.WriteString("totp_enabled_at=") builder.WriteString(v.Format(time.ANSIC)) } + builder.WriteString(", ") + builder.WriteString("sora_storage_quota_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) + builder.WriteString(", ") + builder.WriteString("sora_storage_used_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index ae9418ff..155b9160 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,6 +43,10 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" + // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. + FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" + // FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database. + FieldSoraStorageUsedBytes = "sora_storage_used_bytes" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -152,6 +156,8 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, + FieldSoraStorageQuotaBytes, + FieldSoraStorageUsedBytes, } var ( @@ -208,6 +214,10 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool + // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. + DefaultSoraStorageQuotaBytes int64 + // DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field. + DefaultSoraStorageUsedBytes int64 ) // OrderOption defines the ordering options for the User queries. @@ -288,6 +298,16 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } +// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. +func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() +} + +// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field. +func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 1de61037..e26afcf3 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,6 +125,16 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } +// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. +func SoraStorageQuotaBytes(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ. +func SoraStorageUsedBytes(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -860,6 +870,86 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } +// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesEQ(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNEQ(v int64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGT(v int64) predicate.User { + return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGTE(v int64) predicate.User { + return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLT(v int64) predicate.User { + return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLTE(v int64) predicate.User { + return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesEQ(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesNEQ(v int64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...)) +} + +// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...)) +} + +// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesGT(v int64) predicate.User { + return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesGTE(v int64) predicate.User { + return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesLT(v int64) predicate.User { + return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesLTE(v int64) predicate.User { + return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index f862a580..df0c6bcc 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -210,6 +210,34 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate { + _c.mutation.SetSoraStorageQuotaBytes(v) + return _c +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate { + if v != nil { + _c.SetSoraStorageQuotaBytes(*v) + } + return _c +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate { + _c.mutation.SetSoraStorageUsedBytes(v) + return _c +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate { + if v != nil { + _c.SetSoraStorageUsedBytes(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -424,6 +452,14 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + v := user.DefaultSoraStorageQuotaBytes + _c.mutation.SetSoraStorageQuotaBytes(v) + } + if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { + v := user.DefaultSoraStorageUsedBytes + _c.mutation.SetSoraStorageUsedBytes(v) + } return nil } @@ -487,6 +523,12 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)} + } + if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { + return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)} + } return nil } @@ -570,6 +612,14 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } + if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + _node.SoraStorageQuotaBytes = value + } + if value, ok := _c.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + _node.SoraStorageUsedBytes = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -956,6 +1006,42 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert { + u.Set(user.FieldSoraStorageQuotaBytes, v) + return u +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert { + u.SetExcluded(user.FieldSoraStorageQuotaBytes) + return u +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert { + u.Add(user.FieldSoraStorageQuotaBytes, v) + return u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert { + u.Set(user.FieldSoraStorageUsedBytes, v) + return u +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert { + u.SetExcluded(user.FieldSoraStorageUsedBytes) + return u +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert { + u.Add(user.FieldSoraStorageUsedBytes, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1218,6 +1304,48 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageUsedBytes(v) + }) +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageUsedBytes(v) + }) +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageUsedBytes() + }) +} + // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1646,6 +1774,48 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageUsedBytes(v) + }) +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageUsedBytes(v) + }) +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageUsedBytes() + }) +} + // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 80222c92..f71f0cad 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -242,6 +242,48 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate { + _u.mutation.ResetSoraStorageUsedBytes() + _u.mutation.SetSoraStorageUsedBytes(v) + return _u +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate { + if v != nil { + _u.SetSoraStorageUsedBytes(*v) + } + return _u +} + +// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. +func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate { + _u.mutation.AddSoraStorageUsedBytes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -709,6 +751,18 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { + _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1352,6 +1406,48 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne { + _u.mutation.ResetSoraStorageUsedBytes() + _u.mutation.SetSoraStorageUsedBytes(v) + return _u +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne { + if v != nil { + _u.SetSoraStorageUsedBytes(*v) + } + return _u +} + +// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. +func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne { + _u.mutation.AddSoraStorageUsedBytes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1849,6 +1945,18 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { + _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/go.mod b/backend/go.mod index 6916057f..a34c9fff 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -5,6 +5,13 @@ go 1.25.7 require ( entgo.io/ent v0.14.5 github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/DouDOU-start/go-sora2api v1.1.0 + github.com/alitto/pond/v2 v2.6.2 + github.com/aws/aws-sdk-go-v2/config v1.32.10 + github.com/aws/aws-sdk-go-v2/credentials v1.19.10 + github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 + github.com/cespare/xxhash/v2 v2.3.0 + github.com/coder/websocket v1.8.14 github.com/dgraph-io/ristretto v0.2.0 github.com/gin-gonic/gin v1.9.1 github.com/golang-jwt/jwt/v5 v5.2.2 @@ -13,9 +20,10 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/imroc/req/v3 v3.57.0 github.com/lib/pq v1.10.9 + github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pquerna/otp v1.5.0 github.com/redis/go-redis/v9 v9.17.2 - github.com/refraction-networking/utls v1.8.1 + github.com/refraction-networking/utls v1.8.2 github.com/robfig/cron/v3 v3.0.1 github.com/shirou/gopsutil/v4 v4.25.6 github.com/spf13/viper v1.18.2 @@ -25,10 +33,14 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/zeromicro/go-zero v1.9.4 - golang.org/x/crypto v0.47.0 + go.uber.org/zap v1.24.0 + golang.org/x/crypto v0.48.0 golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 - golang.org/x/term v0.39.0 + golang.org/x/term v0.40.0 + google.golang.org/grpc v1.75.1 + google.golang.org/protobuf v1.36.10 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.3 ) @@ -41,11 +53,33 @@ require ( github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect + github.com/aws/smithy-go v1.24.1 // indirect + github.com/bdandy/go-errors v1.2.2 // indirect + github.com/bdandy/go-socks4 v1.2.3 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect + github.com/bogdanfinn/fhttp v0.6.8 // indirect + github.com/bogdanfinn/quic-go-utls v1.0.9-utls // indirect + github.com/bogdanfinn/tls-client v1.14.0 // indirect + github.com/bogdanfinn/utls v1.7.7-barnius // indirect + github.com/bogdanfinn/websocket v1.5.5-barnius // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect @@ -75,6 +109,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -119,6 +154,7 @@ require ( github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 // indirect github.com/testcontainers/testcontainers-go v0.40.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect @@ -133,16 +169,17 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect go.opentelemetry.io/otel/metric v1.37.0 // indirect - go.opentelemetry.io/otel/sdk v1.37.0 // indirect go.opentelemetry.io/otel/trace v1.37.0 // indirect go.uber.org/atomic v1.10.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/mod v0.31.0 // indirect - golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.33.0 // indirect + golang.org/x/mod v0.32.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect + golang.org/x/tools v0.41.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 171995c7..32e389a7 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -10,16 +10,74 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/DouDOU-start/go-sora2api v1.1.0 h1:PxWiukK77StiHxEngOFwT1rKUn9oTAJJTl07wQUXwiU= +github.com/DouDOU-start/go-sora2api v1.1.0/go.mod h1:dcwpethoKfAsMWskDD9iGgc/3yox2tkthPLSMVGnhkE= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= +github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw= +github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= +github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls= +github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c= +github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI= +github.com/aws/aws-sdk-go-v2/config v1.32.10/go.mod h1:2rUIOnA2JaiqYmSKYmRJlcMWy6qTj1vuRFscppSBMcw= +github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8= +github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 h1:Ii4s+Sq3yDfaMLpjrJsqD6SmG/Wq/P5L/hw2qa78UAY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18/go.mod h1:6x81qnY++ovptLE6nWQeWrpXxbnlIex+4H4eYYGcqfc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 h1:eZioDaZGJ0tMM4gzmkNIO2aAoQd+je7Ug7TkvAzlmkU= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18/go.mod h1:CCXwUKAJdoWr6/NcxZ+zsiPr6oH/Q5aTooRGYieAyj4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 h1:CeY9LUdur+Dxoeldqoun6y4WtJ3RQtzk0JMP2gfUay0= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5/go.mod h1:AZLZf2fMaahW5s/wMRciu1sYbdsikT/UHwbUjOdEVTc= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 h1:fJvQ5mIBVfKtiyx0AHY6HeWcRX5LGANLpq8SVR+Uazs= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10/go.mod h1:Kzm5e6OmNH8VMkgK9t+ry5jEih4Y8whqs+1hrkxim1I= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 h1:LTRCYFlnnKFlKsyIQxKhJuDuA3ZkrDQMRYm6rXiHlLY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18/go.mod h1:XhwkgGG6bHSd00nO/mexWTcTjgd6PjuvWQMqSn2UaEk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 h1:/A/xDuZAVD2BpsS2fftFRo/NoEKQJ8YTnJDEHBy2Gtg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18/go.mod h1:hWe9b4f+djUQGmyiGEeOnZv69dtMSgpDRIvNMvuvzvY= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 h1:M1A9AjcFwlxTLuf0Faj88L8Iqw0n/AJHjpZTQzMMsSc= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2/go.mod h1:KsdTV6Q9WKUZm2mNJnUFmIoXfZux91M3sr/a4REX8e0= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 h1:MzORe+J94I+hYu2a6XmV5yC9huoTv8NRcCrUNedDypQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.6/go.mod h1:hXzcHLARD7GeWnifd8j9RWqtfIgxj4/cAtIVIK7hg8g= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 h1:7oGD8KPfBOJGXiCoRKrrrQkbvCp8N++u36hrLMPey6o= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.11/go.mod h1:0DO9B5EUJQlIDif+XJRWCljZRKsAFKh3gpFz7UnDtOo= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWAXLGFIizeqkdkKgRlJwWc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs= +github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= +github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= +github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM= +github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic= +github.com/bdandy/go-socks4 v1.2.3/go.mod h1:98kiVFgpdogR8aIGLWLvjDVZ8XcKPsSI/ypGrO+bqHI= +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0= github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= +github.com/bogdanfinn/fhttp v0.6.8 h1:LiQyHOY3i0QoxxNB7nq27/nGNNbtPj0fuBPozhR7Ws4= +github.com/bogdanfinn/fhttp v0.6.8/go.mod h1:A+EKDzMx2hb4IUbMx4TlkoHnaJEiLl8r/1Ss1Y+5e5M= +github.com/bogdanfinn/quic-go-utls v1.0.9-utls h1:tV6eDEiRbRCcepALSzxR94JUVD3N3ACIiRLgyc2Ep8s= +github.com/bogdanfinn/quic-go-utls v1.0.9-utls/go.mod h1:aHph9B9H9yPOt5xnhWKSOum27DJAqpiHzwX+gjvaXcg= +github.com/bogdanfinn/tls-client v1.14.0 h1:vyk7Cn4BIvLAGVuMfb0tP22OqogfO1lYamquQNEZU1A= +github.com/bogdanfinn/tls-client v1.14.0/go.mod h1:LsU6mXVn8MOFDwTkyRfI7V1BZM1p0wf2ZfZsICW/1fM= +github.com/bogdanfinn/utls v1.7.7-barnius h1:OuJ497cc7F3yKNVHRsYPQdGggmk5x6+V5ZlrCR7fOLU= +github.com/bogdanfinn/utls v1.7.7-barnius/go.mod h1:aAK1VZQlpKZClF1WEQeq6kyclbkPq4hz6xTbB5xSlmg= +github.com/bogdanfinn/websocket v1.5.5-barnius h1:bY+qnxpai1qe7Jmjx+Sds/cmOSpuuLoR8x61rWltjOI= +github.com/bogdanfinn/websocket v1.5.5-barnius/go.mod h1:gvvEw6pTKHb7yOiFvIfAFTStQWyrm25BMVCTj5wRSsI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -36,6 +94,12 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= +github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= @@ -107,6 +171,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -116,6 +182,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -170,6 +238,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -203,10 +273,14 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -226,8 +300,8 @@ github.com/quic-go/quic-go v0.57.1 h1:25KAAR9QR8KZrCZRThWMKVAwGoiHIrNbT72ULHTuI1 github.com/quic-go/quic-go v0.57.1/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s= github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI= github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= -github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo= -github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= +github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo= +github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= @@ -252,6 +326,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= @@ -273,6 +349,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc= +github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5/go.mod h1:2JjD2zLQYH5HO74y5+aE3remJQvl6q4Sn6aWA2wD1Ng= github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU= github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY= github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk= @@ -320,6 +398,8 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= @@ -328,25 +408,32 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= +go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= -golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= +golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -354,17 +441,22 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= -golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= -golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk= @@ -379,6 +471,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 91437ba8..c1f54ab6 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -5,7 +5,7 @@ import ( "crypto/rand" "encoding/hex" "fmt" - "log" + "log/slog" "net/url" "os" "strings" @@ -19,10 +19,25 @@ const ( RunModeSimple = "simple" ) +// 使用量记录队列溢出策略 +const ( + UsageRecordOverflowPolicyDrop = "drop" + UsageRecordOverflowPolicySample = "sample" + UsageRecordOverflowPolicySync = "sync" +) + // DefaultCSPPolicy is the default Content-Security-Policy with nonce support // __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" +// UMQ(用户消息队列)模式常量 +const ( + // UMQModeSerialize: 账号级串行锁 + RPM 自适应延迟 + UMQModeSerialize = "serialize" + // UMQModeThrottle: 仅 RPM 自适应前置延迟,不阻塞并发 + UMQModeThrottle = "throttle" +) + // 连接池隔离策略常量 // 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗 const ( @@ -38,31 +53,68 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - Ops OpsConfig `mapstructure:"ops"` - JWT JWTConfig `mapstructure:"jwt"` - Totp TotpConfig `mapstructure:"totp"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` - Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` - DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` - UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + Log LogConfig `mapstructure:"log"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + Ops OpsConfig `mapstructure:"ops"` + JWT JWTConfig `mapstructure:"jwt"` + Totp TotpConfig `mapstructure:"totp"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"` + SubscriptionMaintenance SubscriptionMaintenanceConfig `mapstructure:"subscription_maintenance"` + Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` + DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` + UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + Sora SoraConfig `mapstructure:"sora"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` + Idempotency IdempotencyConfig `mapstructure:"idempotency"` +} + +type LogConfig struct { + Level string `mapstructure:"level"` + Format string `mapstructure:"format"` + ServiceName string `mapstructure:"service_name"` + Environment string `mapstructure:"env"` + Caller bool `mapstructure:"caller"` + StacktraceLevel string `mapstructure:"stacktrace_level"` + Output LogOutputConfig `mapstructure:"output"` + Rotation LogRotationConfig `mapstructure:"rotation"` + Sampling LogSamplingConfig `mapstructure:"sampling"` +} + +type LogOutputConfig struct { + ToStdout bool `mapstructure:"to_stdout"` + ToFile bool `mapstructure:"to_file"` + FilePath string `mapstructure:"file_path"` +} + +type LogRotationConfig struct { + MaxSizeMB int `mapstructure:"max_size_mb"` + MaxBackups int `mapstructure:"max_backups"` + MaxAgeDays int `mapstructure:"max_age_days"` + Compress bool `mapstructure:"compress"` + LocalTime bool `mapstructure:"local_time"` +} + +type LogSamplingConfig struct { + Enabled bool `mapstructure:"enabled"` + Initial int `mapstructure:"initial"` + Thereafter int `mapstructure:"thereafter"` } type GeminiConfig struct { @@ -94,6 +146,25 @@ type UpdateConfig struct { ProxyURL string `mapstructure:"proxy_url"` } +type IdempotencyConfig struct { + // ObserveOnly 为 true 时处于观察期:未携带 Idempotency-Key 的请求继续放行。 + ObserveOnly bool `mapstructure:"observe_only"` + // DefaultTTLSeconds 关键写接口的幂等记录默认 TTL(秒)。 + DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"` + // SystemOperationTTLSeconds 系统操作接口的幂等记录 TTL(秒)。 + SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"` + // ProcessingTimeoutSeconds processing 状态锁超时(秒)。 + ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"` + // FailedRetryBackoffSeconds 失败退避窗口(秒)。 + FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"` + // MaxStoredResponseLen 持久化响应体最大长度(字节)。 + MaxStoredResponseLen int `mapstructure:"max_stored_response_len"` + // CleanupIntervalSeconds 过期记录清理周期(秒)。 + CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` + // CleanupBatchSize 每次清理的最大记录数。 + CleanupBatchSize int `mapstructure:"cleanup_batch_size"` +} + type LinuxDoConnectConfig struct { Enabled bool `mapstructure:"enabled"` ClientID string `mapstructure:"client_id"` @@ -126,6 +197,8 @@ type TokenRefreshConfig struct { MaxRetries int `mapstructure:"max_retries"` // 重试退避基础时间(秒) RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` + // 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭) + SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"` } type PricingConfig struct { @@ -147,6 +220,7 @@ type ServerConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` Mode string `mapstructure:"mode"` // debug/release + FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL,用于生成邮件中的外部链接 ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) @@ -173,6 +247,7 @@ type SecurityConfig struct { URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"` ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"` CSP CSPConfig `mapstructure:"csp"` + ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"` ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"` } @@ -197,6 +272,17 @@ type CSPConfig struct { Policy string `mapstructure:"policy"` } +type ProxyFallbackConfig struct { + // AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。 + // 仅影响以下非 AI 账号连接的辅助服务: + // - GitHub Release 更新检查 + // - 定价数据拉取 + // 不影响 AI 账号网关连接(Claude/OpenAI/Gemini/Antigravity), + // 这些关键路径的代理失败始终返回错误,不会回退直连。 + // 默认 false:避免因代理配置错误导致服务器真实 IP 泄露。 + AllowDirectOnError bool `mapstructure:"allow_direct_on_error"` +} + type ProxyProbeConfig struct { InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证 } @@ -217,6 +303,59 @@ type ConcurrencyConfig struct { PingInterval int `mapstructure:"ping_interval"` } +// SoraConfig 直连 Sora 配置 +type SoraConfig struct { + Client SoraClientConfig `mapstructure:"client"` + Storage SoraStorageConfig `mapstructure:"storage"` +} + +// SoraClientConfig 直连 Sora 客户端配置 +type SoraClientConfig struct { + BaseURL string `mapstructure:"base_url"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + MaxRetries int `mapstructure:"max_retries"` + CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"` + PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` + MaxPollAttempts int `mapstructure:"max_poll_attempts"` + RecentTaskLimit int `mapstructure:"recent_task_limit"` + RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` + Debug bool `mapstructure:"debug"` + UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"` + Headers map[string]string `mapstructure:"headers"` + UserAgent string `mapstructure:"user_agent"` + DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` + CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"` +} + +// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置 +type SoraCurlCFFISidecarConfig struct { + Enabled bool `mapstructure:"enabled"` + BaseURL string `mapstructure:"base_url"` + Impersonate string `mapstructure:"impersonate"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"` + SessionTTLSeconds int `mapstructure:"session_ttl_seconds"` +} + +// SoraStorageConfig 媒体存储配置 +type SoraStorageConfig struct { + Type string `mapstructure:"type"` + LocalPath string `mapstructure:"local_path"` + FallbackToUpstream bool `mapstructure:"fallback_to_upstream"` + MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"` + DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"` + MaxDownloadBytes int64 `mapstructure:"max_download_bytes"` + Debug bool `mapstructure:"debug"` + Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"` +} + +// SoraStorageCleanupConfig 媒体清理配置 +type SoraStorageCleanupConfig struct { + Enabled bool `mapstructure:"enabled"` + Schedule string `mapstructure:"schedule"` + RetentionDays int `mapstructure:"retention_days"` +} + // GatewayConfig API网关相关配置 type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 @@ -224,8 +363,22 @@ type GatewayConfig struct { ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` // 请求体最大字节数,用于网关请求体大小限制 MaxBodySize int64 `mapstructure:"max_body_size"` + // 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大 + UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"` + // 代理探测响应体读取上限(字节) + ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"` + // Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销) + GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"` // ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy) ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` + // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 + // 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。 + ForceCodexCLI bool `mapstructure:"force_codex_cli"` + // OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头 + // 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。 + OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` + // OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP) + OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -271,6 +424,24 @@ type GatewayConfig struct { // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) FailoverOn400 bool `mapstructure:"failover_on_400"` + // Sora 专用配置 + // SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size) + SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"` + // SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制) + SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"` + // SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制) + SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"` + // SoraStreamMode: stream 强制策略(force/error) + SoraStreamMode string `mapstructure:"sora_stream_mode"` + // SoraModelFilters: 模型列表过滤配置 + SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"` + // SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key + SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"` + // SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名) + SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"` + // SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用) + SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"` + // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限) MaxAccountSwitches int `mapstructure:"max_account_switches"` // Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格) @@ -284,6 +455,194 @@ type GatewayConfig struct { // TLSFingerprint: TLS指纹伪装配置 TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"` + + // UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker) + UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"` + + // UserGroupRateCacheTTLSeconds: 用户分组倍率热路径缓存 TTL(秒) + UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"` + // ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒) + ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"` + + // UserMessageQueue: 用户消息串行队列配置 + // 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟 + UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"` +} + +// UserMessageQueueConfig 用户消息串行队列配置 +// 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送 +type UserMessageQueueConfig struct { + // Mode: 模式选择 + // "serialize" = 账号级串行锁 + RPM 自适应延迟 + // "throttle" = 仅 RPM 自适应前置延迟,不阻塞并发 + // "" = 禁用(默认) + Mode string `mapstructure:"mode"` + // Enabled: 已废弃,仅向后兼容(等同于 mode: "serialize") + Enabled bool `mapstructure:"enabled"` + // LockTTLMs: 串行锁 TTL(毫秒),应大于最长请求时间 + LockTTLMs int `mapstructure:"lock_ttl_ms"` + // WaitTimeoutMs: 等待获取锁的超时时间(毫秒) + WaitTimeoutMs int `mapstructure:"wait_timeout_ms"` + // MinDelayMs: RPM 自适应延迟下限(毫秒) + MinDelayMs int `mapstructure:"min_delay_ms"` + // MaxDelayMs: RPM 自适应延迟上限(毫秒) + MaxDelayMs int `mapstructure:"max_delay_ms"` + // CleanupIntervalSeconds: 孤儿锁清理间隔(秒),0 表示禁用 + CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` +} + +// WaitTimeout 返回等待超时的 time.Duration +func (c *UserMessageQueueConfig) WaitTimeout() time.Duration { + if c.WaitTimeoutMs <= 0 { + return 30 * time.Second + } + return time.Duration(c.WaitTimeoutMs) * time.Millisecond +} + +// GetEffectiveMode 返回生效的模式 +// 注意:Mode 字段已在 load() 中做过白名单校验和规范化,此处无需重复验证 +func (c *UserMessageQueueConfig) GetEffectiveMode() string { + if c.Mode == UMQModeSerialize || c.Mode == UMQModeThrottle { + return c.Mode + } + if c.Enabled { + return UMQModeSerialize // 向后兼容 + } + return "" +} + +// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 +// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。 +type GatewayOpenAIWSConfig struct { + // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为) + ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"` + // IngressModeDefault: ingress 默认模式(off/shared/dedicated) + IngressModeDefault string `mapstructure:"ingress_mode_default"` + // Enabled: 全局总开关(默认 true) + Enabled bool `mapstructure:"enabled"` + // OAuthEnabled: 是否允许 OpenAI OAuth 账号使用 WS + OAuthEnabled bool `mapstructure:"oauth_enabled"` + // APIKeyEnabled: 是否允许 OpenAI API Key 账号使用 WS + APIKeyEnabled bool `mapstructure:"apikey_enabled"` + // ForceHTTP: 全局强制 HTTP(用于紧急回滚) + ForceHTTP bool `mapstructure:"force_http"` + // AllowStoreRecovery: 允许在 WSv2 下按策略恢复 store=true(默认 false) + AllowStoreRecovery bool `mapstructure:"allow_store_recovery"` + // IngressPreviousResponseRecoveryEnabled: ingress 模式收到 previous_response_not_found 时,是否允许自动去掉 previous_response_id 重试一次(默认 true) + IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"` + // StoreDisabledConnMode: store=false 且无可复用会话连接时的建连策略(strict/adaptive/off) + // - strict: 强制新建连接(隔离优先) + // - adaptive: 仅在高风险失败后强制新建连接(性能与隔离折中) + // - off: 不强制新建连接(复用优先) + StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"` + // StoreDisabledForceNewConn: store=false 且无可复用粘连连接时是否强制新建连接(默认 true,保障会话隔离) + // 兼容旧配置;当 StoreDisabledConnMode 为空时才生效。 + StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"` + // PrewarmGenerateEnabled: 是否启用 WSv2 generate=false 预热(默认 false) + PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"` + + // Feature 开关:v2 优先于 v1 + ResponsesWebsockets bool `mapstructure:"responses_websockets"` + ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"` + + // 连接池参数 + MaxConnsPerAccount int `mapstructure:"max_conns_per_account"` + MinIdlePerAccount int `mapstructure:"min_idle_per_account"` + MaxIdlePerAccount int `mapstructure:"max_idle_per_account"` + // DynamicMaxConnsByAccountConcurrencyEnabled: 是否按账号并发动态计算连接池上限 + DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"` + // OAuthMaxConnsFactor: OAuth 账号连接池系数(effective=ceil(concurrency*factor)) + OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"` + // APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor)) + APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"` + DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` + ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` + WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` + PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"` + QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"` + // EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数) + EventFlushBatchSize int `mapstructure:"event_flush_batch_size"` + // EventFlushIntervalMS: WS 流式写出最大等待时间(毫秒);0 表示仅按 batch 触发 + EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"` + // PrewarmCooldownMS: 连接池预热触发冷却时间(毫秒) + PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"` + // FallbackCooldownSeconds: WS 回退冷却窗口,避免 WS/HTTP 抖动;0 表示关闭冷却 + FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"` + // RetryBackoffInitialMS: WS 重试初始退避(毫秒);<=0 表示关闭退避 + RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"` + // RetryBackoffMaxMS: WS 重试最大退避(毫秒) + RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"` + // RetryJitterRatio: WS 重试退避抖动比例(0-1) + RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"` + // RetryTotalBudgetMS: WS 单次请求重试总预算(毫秒);0 表示关闭预算限制 + RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"` + // PayloadLogSampleRate: payload_schema 日志采样率(0-1) + PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"` + + // 账号调度与粘连参数 + LBTopK int `mapstructure:"lb_top_k"` + // StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL + StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"` + // SessionHashReadOldFallback: 会话哈希迁移期是否允许“新 key 未命中时回退读旧 SHA-256 key” + SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"` + // SessionHashDualWriteOld: 会话哈希迁移期是否双写旧 SHA-256 key(短 TTL) + SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"` + // MetadataBridgeEnabled: RequestMetadata 迁移期是否保留旧 ctxkey.* 兼容桥接 + MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"` + // StickyResponseIDTTLSeconds: response_id -> account_id 粘连 TTL + StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"` + // StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退) + StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"` + + SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"` +} + +// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。 +type GatewayOpenAIWSSchedulerScoreWeights struct { + Priority float64 `mapstructure:"priority"` + Load float64 `mapstructure:"load"` + Queue float64 `mapstructure:"queue"` + ErrorRate float64 `mapstructure:"error_rate"` + TTFT float64 `mapstructure:"ttft"` +} + +// GatewayUsageRecordConfig 使用量记录异步队列配置 +type GatewayUsageRecordConfig struct { + // WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限) + WorkerCount int `mapstructure:"worker_count"` + // QueueSize: 队列容量(有界) + QueueSize int `mapstructure:"queue_size"` + // TaskTimeoutSeconds: 单个使用量记录任务超时(秒) + TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"` + // OverflowPolicy: 队列满时策略(drop/sample/sync) + OverflowPolicy string `mapstructure:"overflow_policy"` + // OverflowSamplePercent: sample 策略下,同步回写采样百分比(1-100) + OverflowSamplePercent int `mapstructure:"overflow_sample_percent"` + + // AutoScaleEnabled: 是否启用 worker 自动扩缩容 + AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"` + // AutoScaleMinWorkers: 自动扩缩容最小 worker 数 + AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"` + // AutoScaleMaxWorkers: 自动扩缩容最大 worker 数 + AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"` + // AutoScaleUpQueuePercent: 队列占用率达到该阈值时触发扩容 + AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"` + // AutoScaleDownQueuePercent: 队列占用率低于该阈值时触发缩容 + AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"` + // AutoScaleUpStep: 每次扩容步长 + AutoScaleUpStep int `mapstructure:"auto_scale_up_step"` + // AutoScaleDownStep: 每次缩容步长 + AutoScaleDownStep int `mapstructure:"auto_scale_down_step"` + // AutoScaleCheckIntervalSeconds: 自动扩缩容检测间隔(秒) + AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"` + // AutoScaleCooldownSeconds: 自动扩缩容冷却时间(秒) + AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"` +} + +// SoraModelFiltersConfig Sora 模型过滤配置 +type SoraModelFiltersConfig struct { + // HidePromptEnhance 是否隐藏 prompt-enhance 模型 + HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"` } // TLSFingerprintConfig TLS指纹伪装配置 @@ -479,8 +838,9 @@ type OpsMetricsCollectorCacheConfig struct { type JWTConfig struct { Secret string `mapstructure:"secret"` ExpireHour int `mapstructure:"expire_hour"` - // AccessTokenExpireMinutes: Access Token有效期(分钟),默认15分钟 - // 短有效期减少被盗用风险,配合Refresh Token实现无感续期 + // AccessTokenExpireMinutes: Access Token有效期(分钟) + // - >0: 使用分钟配置(优先级高于 ExpireHour) + // - =0: 回退使用 ExpireHour(向后兼容旧配置) AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"` // RefreshTokenExpireDays: Refresh Token有效期(天),默认30天 RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"` @@ -525,6 +885,20 @@ type APIKeyAuthCacheConfig struct { Singleflight bool `mapstructure:"singleflight"` } +// SubscriptionCacheConfig 订阅认证 L1 缓存配置 +type SubscriptionCacheConfig struct { + L1Size int `mapstructure:"l1_size"` + L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` + JitterPercent int `mapstructure:"jitter_percent"` +} + +// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置。 +// 用于将“请求路径触发的维护动作”有界化,避免高并发下 goroutine 膨胀。 +type SubscriptionMaintenanceConfig struct { + WorkerCount int `mapstructure:"worker_count"` + QueueSize int `mapstructure:"queue_size"` +} + // DashboardCacheConfig 仪表盘统计缓存配置 type DashboardCacheConfig struct { // Enabled: 是否启用仪表盘缓存 @@ -588,7 +962,19 @@ func NormalizeRunMode(value string) string { } } +// Load 读取并校验完整配置(要求 jwt.secret 已显式提供)。 func Load() (*Config, error) { + return load(false) +} + +// LoadForBootstrap 读取启动阶段配置。 +// +// 启动阶段允许 jwt.secret 先留空,后续由数据库初始化流程补齐并再次完整校验。 +func LoadForBootstrap() (*Config, error) { + return load(true) +} + +func load(allowMissingJWTSecret bool) (*Config, error) { viper.SetConfigName("config") viper.SetConfigType("yaml") @@ -630,6 +1016,7 @@ func Load() (*Config, error) { if cfg.Server.Mode == "" { cfg.Server.Mode = "debug" } + cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL) cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) @@ -648,14 +1035,25 @@ func Load() (*Config, error) { cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy) + cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level)) + cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format)) + cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName) + cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment) + cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) + cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath) - if cfg.JWT.Secret == "" { - secret, err := generateJWTSecret(64) - if err != nil { - return nil, fmt.Errorf("generate jwt secret error: %w", err) - } - cfg.JWT.Secret = secret - log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.") + // 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。 + // 新键未配置(<=0)时回退旧键;新键优先。 + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds + } + + // Normalize UMQ mode: 白名单校验,非法值在加载时一次性 warn 并清空 + if m := cfg.Gateway.UserMessageQueue.Mode; m != "" && m != UMQModeSerialize && m != UMQModeThrottle { + slog.Warn("invalid user_message_queue mode, disabling", + "mode", m, + "valid_modes", []string{UMQModeSerialize, UMQModeThrottle}) + cfg.Gateway.UserMessageQueue.Mode = "" } // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256) @@ -667,29 +1065,39 @@ func Load() (*Config, error) { } cfg.Totp.EncryptionKey = key cfg.Totp.EncryptionKeyConfigured = false - log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.") + slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.") } else { cfg.Totp.EncryptionKeyConfigured = true } + originalJWTSecret := cfg.JWT.Secret + if allowMissingJWTSecret && originalJWTSecret == "" { + // 启动阶段允许先无 JWT 密钥,后续在数据库初始化后补齐。 + cfg.JWT.Secret = strings.Repeat("0", 32) + } + if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("validate config error: %w", err) } + if allowMissingJWTSecret && originalJWTSecret == "" { + cfg.JWT.Secret = "" + } + if !cfg.Security.URLAllowlist.Enabled { - log.Println("Warning: security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).") + slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).") } if !cfg.Security.ResponseHeaders.Enabled { - log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).") + slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).") } if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) { - log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.") + slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.") } if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 { - log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v", - cfg.Security.ResponseHeaders.AdditionalAllowed, - cfg.Security.ResponseHeaders.ForceRemove, + slog.Info("response header policy configured", + "additional_allowed", cfg.Security.ResponseHeaders.AdditionalAllowed, + "force_remove", cfg.Security.ResponseHeaders.ForceRemove, ) } @@ -702,11 +1110,12 @@ func setDefaults() { // Server viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.port", 8080) - viper.SetDefault("server.mode", "debug") + viper.SetDefault("server.mode", "release") + viper.SetDefault("server.frontend_url", "") viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.trusted_proxies", []string{}) - viper.SetDefault("server.max_request_body_size", int64(100*1024*1024)) + viper.SetDefault("server.max_request_body_size", int64(256*1024*1024)) // H2C 默认配置 viper.SetDefault("server.h2c.enabled", false) viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流 @@ -715,6 +1124,25 @@ func setDefaults() { viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB + // Log + viper.SetDefault("log.level", "info") + viper.SetDefault("log.format", "console") + viper.SetDefault("log.service_name", "sub2api") + viper.SetDefault("log.env", "production") + viper.SetDefault("log.caller", true) + viper.SetDefault("log.stacktrace_level", "error") + viper.SetDefault("log.output.to_stdout", true) + viper.SetDefault("log.output.to_file", true) + viper.SetDefault("log.output.file_path", "") + viper.SetDefault("log.rotation.max_size_mb", 100) + viper.SetDefault("log.rotation.max_backups", 10) + viper.SetDefault("log.rotation.max_age_days", 7) + viper.SetDefault("log.rotation.compress", true) + viper.SetDefault("log.rotation.local_time", true) + viper.SetDefault("log.sampling.enabled", false) + viper.SetDefault("log.sampling.initial", 100) + viper.SetDefault("log.sampling.thereafter", 100) + // CORS viper.SetDefault("cors.allowed_origins", []string{}) viper.SetDefault("cors.allow_credentials", true) @@ -737,13 +1165,16 @@ func setDefaults() { viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) viper.SetDefault("security.url_allowlist.allow_private_hosts", true) viper.SetDefault("security.url_allowlist.allow_insecure_http", true) - viper.SetDefault("security.response_headers.enabled", false) + viper.SetDefault("security.response_headers.enabled", true) viper.SetDefault("security.response_headers.additional_allowed", []string{}) viper.SetDefault("security.response_headers.force_remove", []string{}) viper.SetDefault("security.csp.enabled", true) viper.SetDefault("security.csp.policy", DefaultCSPPolicy) viper.SetDefault("security.proxy_probe.insecure_skip_verify", false) + // Security - disable direct fallback on proxy error + viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) + // Billing viper.SetDefault("billing.circuit_breaker.enabled", true) viper.SetDefault("billing.circuit_breaker.failure_threshold", 5) @@ -775,9 +1206,9 @@ func setDefaults() { viper.SetDefault("database.user", "postgres") viper.SetDefault("database.password", "postgres") viper.SetDefault("database.dbname", "sub2api") - viper.SetDefault("database.sslmode", "disable") - viper.SetDefault("database.max_open_conns", 50) - viper.SetDefault("database.max_idle_conns", 10) + viper.SetDefault("database.sslmode", "prefer") + viper.SetDefault("database.max_open_conns", 256) + viper.SetDefault("database.max_idle_conns", 128) viper.SetDefault("database.conn_max_lifetime_minutes", 30) viper.SetDefault("database.conn_max_idle_time_minutes", 5) @@ -789,8 +1220,8 @@ func setDefaults() { viper.SetDefault("redis.dial_timeout_seconds", 5) viper.SetDefault("redis.read_timeout_seconds", 3) viper.SetDefault("redis.write_timeout_seconds", 3) - viper.SetDefault("redis.pool_size", 128) - viper.SetDefault("redis.min_idle_conns", 10) + viper.SetDefault("redis.pool_size", 1024) + viper.SetDefault("redis.min_idle_conns", 128) viper.SetDefault("redis.enable_tls", false) // Ops (vNext) @@ -810,9 +1241,9 @@ func setDefaults() { // JWT viper.SetDefault("jwt.secret", "") viper.SetDefault("jwt.expire_hour", 24) - viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期 - viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期 - viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新 + viper.SetDefault("jwt.access_token_expire_minutes", 0) // 0 表示回退到 expire_hour + viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期 + viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新 // TOTP viper.SetDefault("totp.encryption_key", "") @@ -830,9 +1261,9 @@ func setDefaults() { // RateLimit viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) - // Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查 - viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json") - viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.sha256") + // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移) + viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json") + viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256") viper.SetDefault("pricing.data_dir", "./data") viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json") viper.SetDefault("pricing.update_interval_hours", 24) @@ -849,6 +1280,11 @@ func setDefaults() { viper.SetDefault("api_key_auth_cache.jitter_percent", 10) viper.SetDefault("api_key_auth_cache.singleflight", true) + // Subscription auth L1 cache + viper.SetDefault("subscription_cache.l1_size", 16384) + viper.SetDefault("subscription_cache.l1_ttl_seconds", 10) + viper.SetDefault("subscription_cache.jitter_percent", 10) + // Dashboard cache viper.SetDefault("dashboard_cache.enabled", true) viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") @@ -874,6 +1310,16 @@ func setDefaults() { viper.SetDefault("usage_cleanup.worker_interval_seconds", 10) viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800) + // Idempotency + viper.SetDefault("idempotency.observe_only", true) + viper.SetDefault("idempotency.default_ttl_seconds", 86400) + viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600) + viper.SetDefault("idempotency.processing_timeout_seconds", 30) + viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5) + viper.SetDefault("idempotency.max_stored_response_len", 64*1024) + viper.SetDefault("idempotency.cleanup_interval_seconds", 60) + viper.SetDefault("idempotency.cleanup_batch_size", 500) + // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", true) @@ -882,13 +1328,72 @@ func setDefaults() { viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.max_account_switches", 10) viper.SetDefault("gateway.max_account_switches_gemini", 3) + viper.SetDefault("gateway.force_codex_cli", false) + viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) + // OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚) + viper.SetDefault("gateway.openai_ws.enabled", true) + viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false) + viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared") + viper.SetDefault("gateway.openai_ws.oauth_enabled", true) + viper.SetDefault("gateway.openai_ws.apikey_enabled", true) + viper.SetDefault("gateway.openai_ws.force_http", false) + viper.SetDefault("gateway.openai_ws.allow_store_recovery", false) + viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true) + viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict") + viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true) + viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false) + viper.SetDefault("gateway.openai_ws.responses_websockets", false) + viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true) + viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128) + viper.SetDefault("gateway.openai_ws.min_idle_per_account", 4) + viper.SetDefault("gateway.openai_ws.max_idle_per_account", 12) + viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true) + viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0) + viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0) + viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10) + viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900) + viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120) + viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7) + viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64) + viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1) + viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10) + viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300) + viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30) + viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120) + viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000) + viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2) + viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000) + viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2) + viper.SetDefault("gateway.openai_ws.lb_top_k", 7) + viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true) + viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true) + viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true) + viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) - viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + viper.SetDefault("gateway.antigravity_extra_retries", 10) + viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) + viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) + viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) + viper.SetDefault("gateway.gemini_debug_response_headers", false) + viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) + viper.SetDefault("gateway.sora_stream_timeout_seconds", 900) + viper.SetDefault("gateway.sora_request_timeout_seconds", 180) + viper.SetDefault("gateway.sora_stream_mode", "force") + viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true) + viper.SetDefault("gateway.sora_media_require_api_key", true) + viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) - viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) + viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大) viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) - viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) + viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+) viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒) viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.client_idle_ttl_seconds", 900) @@ -912,16 +1417,73 @@ func setDefaults() { viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3) viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000) viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300) + viper.SetDefault("gateway.usage_record.worker_count", 128) + viper.SetDefault("gateway.usage_record.queue_size", 16384) + viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5) + viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample) + viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10) + viper.SetDefault("gateway.usage_record.auto_scale_enabled", true) + viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128) + viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512) + viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70) + viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15) + viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32) + viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16) + viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3) + viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10) + viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30) + viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15) // TLS指纹伪装配置(默认关闭,需要账号级别单独启用) + // 用户消息串行队列默认值 + viper.SetDefault("gateway.user_message_queue.enabled", false) + viper.SetDefault("gateway.user_message_queue.lock_ttl_ms", 120000) + viper.SetDefault("gateway.user_message_queue.wait_timeout_ms", 30000) + viper.SetDefault("gateway.user_message_queue.min_delay_ms", 200) + viper.SetDefault("gateway.user_message_queue.max_delay_ms", 2000) + viper.SetDefault("gateway.user_message_queue.cleanup_interval_seconds", 60) + viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("concurrency.ping_interval", 10) + // Sora 直连配置 + viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend") + viper.SetDefault("sora.client.timeout_seconds", 120) + viper.SetDefault("sora.client.max_retries", 3) + viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900) + viper.SetDefault("sora.client.poll_interval_seconds", 2) + viper.SetDefault("sora.client.max_poll_attempts", 600) + viper.SetDefault("sora.client.recent_task_limit", 50) + viper.SetDefault("sora.client.recent_task_limit_max", 200) + viper.SetDefault("sora.client.debug", false) + viper.SetDefault("sora.client.use_openai_token_provider", false) + viper.SetDefault("sora.client.headers", map[string]string{}) + viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + viper.SetDefault("sora.client.disable_tls_fingerprint", false) + viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true) + viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080") + viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131") + viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60) + viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true) + viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600) + + viper.SetDefault("sora.storage.type", "local") + viper.SetDefault("sora.storage.local_path", "") + viper.SetDefault("sora.storage.fallback_to_upstream", true) + viper.SetDefault("sora.storage.max_concurrent_downloads", 4) + viper.SetDefault("sora.storage.download_timeout_seconds", 120) + viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20)) + viper.SetDefault("sora.storage.debug", false) + viper.SetDefault("sora.storage.cleanup.enabled", true) + viper.SetDefault("sora.storage.cleanup.retention_days", 7) + viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *") + // TokenRefresh viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 + viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token // Gemini OAuth - configure via environment variables or config file // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET @@ -930,9 +1492,103 @@ func setDefaults() { viper.SetDefault("gemini.oauth.client_secret", "") viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.quota.policy", "") + + // Subscription Maintenance (bounded queue + worker pool) + viper.SetDefault("subscription_maintenance.worker_count", 2) + viper.SetDefault("subscription_maintenance.queue_size", 1024) + } func (c *Config) Validate() error { + jwtSecret := strings.TrimSpace(c.JWT.Secret) + if jwtSecret == "" { + return fmt.Errorf("jwt.secret is required") + } + // NOTE: 按 UTF-8 编码后的字节长度计算。 + // 选择 bytes 而不是 rune 计数,确保二进制/随机串的长度语义更接近“熵”而非“字符数”。 + if len([]byte(jwtSecret)) < 32 { + return fmt.Errorf("jwt.secret must be at least 32 bytes") + } + switch c.Log.Level { + case "debug", "info", "warn", "error": + case "": + return fmt.Errorf("log.level is required") + default: + return fmt.Errorf("log.level must be one of: debug/info/warn/error") + } + switch c.Log.Format { + case "json", "console": + case "": + return fmt.Errorf("log.format is required") + default: + return fmt.Errorf("log.format must be one of: json/console") + } + switch c.Log.StacktraceLevel { + case "none", "error", "fatal": + case "": + return fmt.Errorf("log.stacktrace_level is required") + default: + return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal") + } + if !c.Log.Output.ToStdout && !c.Log.Output.ToFile { + return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false") + } + if c.Log.Rotation.MaxSizeMB <= 0 { + return fmt.Errorf("log.rotation.max_size_mb must be positive") + } + if c.Log.Rotation.MaxBackups < 0 { + return fmt.Errorf("log.rotation.max_backups must be non-negative") + } + if c.Log.Rotation.MaxAgeDays < 0 { + return fmt.Errorf("log.rotation.max_age_days must be non-negative") + } + if c.Log.Sampling.Enabled { + if c.Log.Sampling.Initial <= 0 { + return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled") + } + if c.Log.Sampling.Thereafter <= 0 { + return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled") + } + } else { + if c.Log.Sampling.Initial < 0 { + return fmt.Errorf("log.sampling.initial must be non-negative") + } + if c.Log.Sampling.Thereafter < 0 { + return fmt.Errorf("log.sampling.thereafter must be non-negative") + } + } + + if c.SubscriptionMaintenance.WorkerCount < 0 { + return fmt.Errorf("subscription_maintenance.worker_count must be non-negative") + } + if c.SubscriptionMaintenance.QueueSize < 0 { + return fmt.Errorf("subscription_maintenance.queue_size must be non-negative") + } + + // Gemini OAuth 配置校验:client_id 与 client_secret 必须同时设置或同时留空。 + // 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。 + geminiClientID := strings.TrimSpace(c.Gemini.OAuth.ClientID) + geminiClientSecret := strings.TrimSpace(c.Gemini.OAuth.ClientSecret) + if (geminiClientID == "") != (geminiClientSecret == "") { + return fmt.Errorf("gemini.oauth.client_id and gemini.oauth.client_secret must be both set or both empty") + } + + if strings.TrimSpace(c.Server.FrontendURL) != "" { + if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil { + return fmt.Errorf("server.frontend_url invalid: %w", err) + } + u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL)) + if err != nil { + return fmt.Errorf("server.frontend_url invalid: %w", err) + } + if u.RawQuery != "" || u.ForceQuery { + return fmt.Errorf("server.frontend_url invalid: must not include query") + } + if u.User != nil { + return fmt.Errorf("server.frontend_url invalid: must not include userinfo") + } + warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL) + } if c.JWT.ExpireHour <= 0 { return fmt.Errorf("jwt.expire_hour must be positive") } @@ -940,20 +1596,20 @@ func (c *Config) Validate() error { return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)") } if c.JWT.ExpireHour > 24 { - log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour) + slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", c.JWT.ExpireHour) } // JWT Refresh Token配置验证 - if c.JWT.AccessTokenExpireMinutes <= 0 { - return fmt.Errorf("jwt.access_token_expire_minutes must be positive") + if c.JWT.AccessTokenExpireMinutes < 0 { + return fmt.Errorf("jwt.access_token_expire_minutes must be non-negative") } if c.JWT.AccessTokenExpireMinutes > 720 { - log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes) + slog.Warn("jwt.access_token_expire_minutes is high; consider shorter expiration for security", "access_token_expire_minutes", c.JWT.AccessTokenExpireMinutes) } if c.JWT.RefreshTokenExpireDays <= 0 { return fmt.Errorf("jwt.refresh_token_expire_days must be positive") } if c.JWT.RefreshTokenExpireDays > 90 { - log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays) + slog.Warn("jwt.refresh_token_expire_days is high; consider shorter expiration for security", "refresh_token_expire_days", c.JWT.RefreshTokenExpireDays) } if c.JWT.RefreshWindowMinutes < 0 { return fmt.Errorf("jwt.refresh_window_minutes must be non-negative") @@ -1159,9 +1815,116 @@ func (c *Config) Validate() error { return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative") } } + if c.Idempotency.DefaultTTLSeconds <= 0 { + return fmt.Errorf("idempotency.default_ttl_seconds must be positive") + } + if c.Idempotency.SystemOperationTTLSeconds <= 0 { + return fmt.Errorf("idempotency.system_operation_ttl_seconds must be positive") + } + if c.Idempotency.ProcessingTimeoutSeconds <= 0 { + return fmt.Errorf("idempotency.processing_timeout_seconds must be positive") + } + if c.Idempotency.FailedRetryBackoffSeconds <= 0 { + return fmt.Errorf("idempotency.failed_retry_backoff_seconds must be positive") + } + if c.Idempotency.MaxStoredResponseLen <= 0 { + return fmt.Errorf("idempotency.max_stored_response_len must be positive") + } + if c.Idempotency.CleanupIntervalSeconds <= 0 { + return fmt.Errorf("idempotency.cleanup_interval_seconds must be positive") + } + if c.Idempotency.CleanupBatchSize <= 0 { + return fmt.Errorf("idempotency.cleanup_batch_size must be positive") + } if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } + if c.Gateway.UpstreamResponseReadMaxBytes <= 0 { + return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive") + } + if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 { + return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive") + } + if c.Gateway.SoraMaxBodySize < 0 { + return fmt.Errorf("gateway.sora_max_body_size must be non-negative") + } + if c.Gateway.SoraStreamTimeoutSeconds < 0 { + return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative") + } + if c.Gateway.SoraRequestTimeoutSeconds < 0 { + return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative") + } + if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 { + return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative") + } + if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" { + switch mode { + case "force", "error": + default: + return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error") + } + } + if c.Sora.Client.TimeoutSeconds < 0 { + return fmt.Errorf("sora.client.timeout_seconds must be non-negative") + } + if c.Sora.Client.MaxRetries < 0 { + return fmt.Errorf("sora.client.max_retries must be non-negative") + } + if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 { + return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") + } + if c.Sora.Client.PollIntervalSeconds < 0 { + return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative") + } + if c.Sora.Client.MaxPollAttempts < 0 { + return fmt.Errorf("sora.client.max_poll_attempts must be non-negative") + } + if c.Sora.Client.RecentTaskLimit < 0 { + return fmt.Errorf("sora.client.recent_task_limit must be non-negative") + } + if c.Sora.Client.RecentTaskLimitMax < 0 { + return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative") + } + if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 && + c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit { + c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit + } + if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 { + return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative") + } + if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 { + return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") + } + if !c.Sora.Client.CurlCFFISidecar.Enabled { + return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true") + } + if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" { + return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required") + } + if c.Sora.Storage.MaxConcurrentDownloads < 0 { + return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative") + } + if c.Sora.Storage.DownloadTimeoutSeconds < 0 { + return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative") + } + if c.Sora.Storage.MaxDownloadBytes < 0 { + return fmt.Errorf("sora.storage.max_download_bytes must be non-negative") + } + if c.Sora.Storage.Cleanup.Enabled { + if c.Sora.Storage.Cleanup.RetentionDays <= 0 { + return fmt.Errorf("sora.storage.cleanup.retention_days must be positive") + } + if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" { + return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled") + } + } else { + if c.Sora.Storage.Cleanup.RetentionDays < 0 { + return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative") + } + } + if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" { + return fmt.Errorf("sora.storage.type must be 'local'") + } if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { switch c.Gateway.ConnectionPoolIsolation { case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: @@ -1183,7 +1946,7 @@ func (c *Config) Validate() error { return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive") } if c.Gateway.IdleConnTimeoutSeconds > 180 { - log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds) + slog.Warn("gateway.idle_conn_timeout_seconds is high; consider 60-120 seconds for better connection reuse", "idle_conn_timeout_seconds", c.Gateway.IdleConnTimeoutSeconds) } if c.Gateway.MaxUpstreamClients <= 0 { return fmt.Errorf("gateway.max_upstream_clients must be positive") @@ -1208,12 +1971,188 @@ func (c *Config) Validate() error { (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") } + // 兼容旧键 sticky_previous_response_ttl_seconds + if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { + c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds + } + if c.Gateway.OpenAIWS.MaxConnsPerAccount <= 0 { + return fmt.Errorf("gateway.openai_ws.max_conns_per_account must be positive") + } + if c.Gateway.OpenAIWS.MinIdlePerAccount < 0 { + return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be non-negative") + } + if c.Gateway.OpenAIWS.MaxIdlePerAccount < 0 { + return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be non-negative") + } + if c.Gateway.OpenAIWS.MinIdlePerAccount > c.Gateway.OpenAIWS.MaxIdlePerAccount { + return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account") + } + if c.Gateway.OpenAIWS.MaxIdlePerAccount > c.Gateway.OpenAIWS.MaxConnsPerAccount { + return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account") + } + if c.Gateway.OpenAIWS.OAuthMaxConnsFactor <= 0 { + return fmt.Errorf("gateway.openai_ws.oauth_max_conns_factor must be positive") + } + if c.Gateway.OpenAIWS.APIKeyMaxConnsFactor <= 0 { + return fmt.Errorf("gateway.openai_ws.apikey_max_conns_factor must be positive") + } + if c.Gateway.OpenAIWS.DialTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.dial_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.ReadTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.read_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.WriteTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.write_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.PoolTargetUtilization <= 0 || c.Gateway.OpenAIWS.PoolTargetUtilization > 1 { + return fmt.Errorf("gateway.openai_ws.pool_target_utilization must be within (0,1]") + } + if c.Gateway.OpenAIWS.QueueLimitPerConn <= 0 { + return fmt.Errorf("gateway.openai_ws.queue_limit_per_conn must be positive") + } + if c.Gateway.OpenAIWS.EventFlushBatchSize <= 0 { + return fmt.Errorf("gateway.openai_ws.event_flush_batch_size must be positive") + } + if c.Gateway.OpenAIWS.EventFlushIntervalMS < 0 { + return fmt.Errorf("gateway.openai_ws.event_flush_interval_ms must be non-negative") + } + if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 { + return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative") + } + if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 { + return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffInitialMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_backoff_initial_ms must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffMaxMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffInitialMS > 0 && c.Gateway.OpenAIWS.RetryBackoffMaxMS > 0 && + c.Gateway.OpenAIWS.RetryBackoffMaxMS < c.Gateway.OpenAIWS.RetryBackoffInitialMS { + return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be >= retry_backoff_initial_ms") + } + if c.Gateway.OpenAIWS.RetryJitterRatio < 0 || c.Gateway.OpenAIWS.RetryJitterRatio > 1 { + return fmt.Errorf("gateway.openai_ws.retry_jitter_ratio must be within [0,1]") + } + if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative") + } + if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" { + switch mode { + case "off", "shared", "dedicated": + default: + return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated") + } + } + if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" { + switch mode { + case "strict", "adaptive", "off": + default: + return fmt.Errorf("gateway.openai_ws.store_disabled_conn_mode must be one of strict|adaptive|off") + } + } + if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 { + return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]") + } + if c.Gateway.OpenAIWS.LBTopK <= 0 { + return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive") + } + if c.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.sticky_session_ttl_seconds must be positive") + } + if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.sticky_response_id_ttl_seconds must be positive") + } + if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 { + return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative") + } + if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative") + } + weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority + + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load + + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue + + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate + + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT + if weightSum <= 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero") + } if c.Gateway.MaxLineSize < 0 { return fmt.Errorf("gateway.max_line_size must be non-negative") } if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 { return fmt.Errorf("gateway.max_line_size must be at least 1MB") } + if c.Gateway.UsageRecord.WorkerCount <= 0 { + return fmt.Errorf("gateway.usage_record.worker_count must be positive") + } + if c.Gateway.UsageRecord.QueueSize <= 0 { + return fmt.Errorf("gateway.usage_record.queue_size must be positive") + } + if c.Gateway.UsageRecord.TaskTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.usage_record.task_timeout_seconds must be positive") + } + switch strings.ToLower(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy)) { + case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync: + default: + return fmt.Errorf("gateway.usage_record.overflow_policy must be one of: %s/%s/%s", + UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync) + } + if c.Gateway.UsageRecord.OverflowSamplePercent < 0 || c.Gateway.UsageRecord.OverflowSamplePercent > 100 { + return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be between 0-100") + } + if strings.EqualFold(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy), UsageRecordOverflowPolicySample) && + c.Gateway.UsageRecord.OverflowSamplePercent <= 0 { + return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be positive when overflow_policy=sample") + } + if c.Gateway.UsageRecord.AutoScaleEnabled { + if c.Gateway.UsageRecord.AutoScaleMinWorkers <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_min_workers must be positive") + } + if c.Gateway.UsageRecord.AutoScaleMaxWorkers <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be positive") + } + if c.Gateway.UsageRecord.AutoScaleMaxWorkers < c.Gateway.UsageRecord.AutoScaleMinWorkers { + return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be >= auto_scale_min_workers") + } + if c.Gateway.UsageRecord.WorkerCount < c.Gateway.UsageRecord.AutoScaleMinWorkers || + c.Gateway.UsageRecord.WorkerCount > c.Gateway.UsageRecord.AutoScaleMaxWorkers { + return fmt.Errorf("gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers") + } + if c.Gateway.UsageRecord.AutoScaleUpQueuePercent <= 0 || c.Gateway.UsageRecord.AutoScaleUpQueuePercent > 100 { + return fmt.Errorf("gateway.usage_record.auto_scale_up_queue_percent must be between 1-100") + } + if c.Gateway.UsageRecord.AutoScaleDownQueuePercent < 0 || c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 100 { + return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be between 0-99") + } + if c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= c.Gateway.UsageRecord.AutoScaleUpQueuePercent { + return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be less than auto_scale_up_queue_percent") + } + if c.Gateway.UsageRecord.AutoScaleUpStep <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_up_step must be positive") + } + if c.Gateway.UsageRecord.AutoScaleDownStep <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_down_step must be positive") + } + if c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_check_interval_seconds must be positive") + } + if c.Gateway.UsageRecord.AutoScaleCooldownSeconds < 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative") + } + } + if c.Gateway.UserGroupRateCacheTTLSeconds <= 0 { + return fmt.Errorf("gateway.user_group_rate_cache_ttl_seconds must be positive") + } + if c.Gateway.ModelsListCacheTTLSeconds < 10 || c.Gateway.ModelsListCacheTTLSeconds > 30 { + return fmt.Errorf("gateway.models_list_cache_ttl_seconds must be between 10-30") + } if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") } @@ -1420,6 +2359,6 @@ func warnIfInsecureURL(field, raw string) { return } if strings.EqualFold(u.Scheme, "http") { - log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field) + slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field) } } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index f734619f..e3b592e2 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -6,8 +6,28 @@ import ( "time" "github.com/spf13/viper" + "github.com/stretchr/testify/require" ) +func resetViperWithJWTSecret(t *testing.T) { + t.Helper() + viper.Reset() + t.Setenv("JWT_SECRET", strings.Repeat("x", 32)) +} + +func TestLoadForBootstrapAllowsMissingJWTSecret(t *testing.T) { + viper.Reset() + t.Setenv("JWT_SECRET", "") + + cfg, err := LoadForBootstrap() + if err != nil { + t.Fatalf("LoadForBootstrap() error: %v", err) + } + if cfg.JWT.Secret != "" { + t.Fatalf("LoadForBootstrap() should keep empty jwt.secret during bootstrap") + } +} + func TestNormalizeRunMode(t *testing.T) { tests := []struct { input string @@ -29,7 +49,7 @@ func TestNormalizeRunMode(t *testing.T) { } func TestLoadDefaultSchedulingConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -56,8 +76,141 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) { } } +func TestLoadDefaultOpenAIWSConfig(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Gateway.OpenAIWS.Enabled { + t.Fatalf("Gateway.OpenAIWS.Enabled = false, want true") + } + if !cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 { + t.Fatalf("Gateway.OpenAIWS.ResponsesWebsocketsV2 = false, want true") + } + if cfg.Gateway.OpenAIWS.ResponsesWebsockets { + t.Fatalf("Gateway.OpenAIWS.ResponsesWebsockets = true, want false") + } + if !cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled { + t.Fatalf("Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false, want true") + } + if cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor != 1.0 { + t.Fatalf("Gateway.OpenAIWS.OAuthMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor) + } + if cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor != 1.0 { + t.Fatalf("Gateway.OpenAIWS.APIKeyMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor) + } + if cfg.Gateway.OpenAIWS.StickySessionTTLSeconds != 3600 { + t.Fatalf("Gateway.OpenAIWS.StickySessionTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) + } + if !cfg.Gateway.OpenAIWS.SessionHashReadOldFallback { + t.Fatalf("Gateway.OpenAIWS.SessionHashReadOldFallback = false, want true") + } + if !cfg.Gateway.OpenAIWS.SessionHashDualWriteOld { + t.Fatalf("Gateway.OpenAIWS.SessionHashDualWriteOld = false, want true") + } + if !cfg.Gateway.OpenAIWS.MetadataBridgeEnabled { + t.Fatalf("Gateway.OpenAIWS.MetadataBridgeEnabled = false, want true") + } + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 3600 { + t.Fatalf("Gateway.OpenAIWS.StickyResponseIDTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + } + if cfg.Gateway.OpenAIWS.FallbackCooldownSeconds != 30 { + t.Fatalf("Gateway.OpenAIWS.FallbackCooldownSeconds = %d, want 30", cfg.Gateway.OpenAIWS.FallbackCooldownSeconds) + } + if cfg.Gateway.OpenAIWS.EventFlushBatchSize != 1 { + t.Fatalf("Gateway.OpenAIWS.EventFlushBatchSize = %d, want 1", cfg.Gateway.OpenAIWS.EventFlushBatchSize) + } + if cfg.Gateway.OpenAIWS.EventFlushIntervalMS != 10 { + t.Fatalf("Gateway.OpenAIWS.EventFlushIntervalMS = %d, want 10", cfg.Gateway.OpenAIWS.EventFlushIntervalMS) + } + if cfg.Gateway.OpenAIWS.PrewarmCooldownMS != 300 { + t.Fatalf("Gateway.OpenAIWS.PrewarmCooldownMS = %d, want 300", cfg.Gateway.OpenAIWS.PrewarmCooldownMS) + } + if cfg.Gateway.OpenAIWS.RetryBackoffInitialMS != 120 { + t.Fatalf("Gateway.OpenAIWS.RetryBackoffInitialMS = %d, want 120", cfg.Gateway.OpenAIWS.RetryBackoffInitialMS) + } + if cfg.Gateway.OpenAIWS.RetryBackoffMaxMS != 2000 { + t.Fatalf("Gateway.OpenAIWS.RetryBackoffMaxMS = %d, want 2000", cfg.Gateway.OpenAIWS.RetryBackoffMaxMS) + } + if cfg.Gateway.OpenAIWS.RetryJitterRatio != 0.2 { + t.Fatalf("Gateway.OpenAIWS.RetryJitterRatio = %v, want 0.2", cfg.Gateway.OpenAIWS.RetryJitterRatio) + } + if cfg.Gateway.OpenAIWS.RetryTotalBudgetMS != 5000 { + t.Fatalf("Gateway.OpenAIWS.RetryTotalBudgetMS = %d, want 5000", cfg.Gateway.OpenAIWS.RetryTotalBudgetMS) + } + if cfg.Gateway.OpenAIWS.PayloadLogSampleRate != 0.2 { + t.Fatalf("Gateway.OpenAIWS.PayloadLogSampleRate = %v, want 0.2", cfg.Gateway.OpenAIWS.PayloadLogSampleRate) + } + if !cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { + t.Fatalf("Gateway.OpenAIWS.StoreDisabledForceNewConn = false, want true") + } + if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" { + t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict") + } + if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled { + t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false") + } + if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" { + t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared") + } +} + +func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0") + t.Setenv("GATEWAY_OPENAI_WS_STICKY_PREVIOUS_RESPONSE_TTL_SECONDS", "7200") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 7200 { + t.Fatalf("StickyResponseIDTTLSeconds = %d, want 7200", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + } +} + +func TestLoadDefaultIdempotencyConfig(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Idempotency.ObserveOnly { + t.Fatalf("Idempotency.ObserveOnly = false, want true") + } + if cfg.Idempotency.DefaultTTLSeconds != 86400 { + t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 86400", cfg.Idempotency.DefaultTTLSeconds) + } + if cfg.Idempotency.SystemOperationTTLSeconds != 3600 { + t.Fatalf("Idempotency.SystemOperationTTLSeconds = %d, want 3600", cfg.Idempotency.SystemOperationTTLSeconds) + } +} + +func TestLoadIdempotencyConfigFromEnv(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("IDEMPOTENCY_OBSERVE_ONLY", "false") + t.Setenv("IDEMPOTENCY_DEFAULT_TTL_SECONDS", "600") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + if cfg.Idempotency.ObserveOnly { + t.Fatalf("Idempotency.ObserveOnly = true, want false") + } + if cfg.Idempotency.DefaultTTLSeconds != 600 { + t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 600", cfg.Idempotency.DefaultTTLSeconds) + } +} + func TestLoadSchedulingConfigFromEnv(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5") cfg, err := Load() @@ -71,7 +224,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) { } func TestLoadDefaultSecurityToggles(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -87,13 +240,69 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { if !cfg.Security.URLAllowlist.AllowPrivateHosts { t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true") } - if cfg.Security.ResponseHeaders.Enabled { - t.Fatalf("ResponseHeaders.Enabled = true, want false") + if !cfg.Security.ResponseHeaders.Enabled { + t.Fatalf("ResponseHeaders.Enabled = false, want true") + } +} + +func TestLoadDefaultServerMode(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Server.Mode != "release" { + t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release") + } +} + +func TestLoadDefaultJWTAccessTokenExpireMinutes(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.JWT.ExpireHour != 24 { + t.Fatalf("JWT.ExpireHour = %d, want 24", cfg.JWT.ExpireHour) + } + if cfg.JWT.AccessTokenExpireMinutes != 0 { + t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 0", cfg.JWT.AccessTokenExpireMinutes) + } +} + +func TestLoadJWTAccessTokenExpireMinutesFromEnv(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "90") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.JWT.AccessTokenExpireMinutes != 90 { + t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 90", cfg.JWT.AccessTokenExpireMinutes) + } +} + +func TestLoadDefaultDatabaseSSLMode(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Database.SSLMode != "prefer" { + t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer") } } func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -118,7 +327,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { } func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -143,7 +352,7 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { } func TestLoadDefaultDashboardCacheConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -168,7 +377,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) { } func TestValidateDashboardCacheConfigEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -188,7 +397,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) { } func TestValidateDashboardCacheConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -207,7 +416,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) { } func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -244,7 +453,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { } func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -263,7 +472,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { } func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -282,7 +491,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { } func TestLoadDefaultUsageCleanupConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -307,7 +516,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) { } func TestValidateUsageCleanupConfigEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -326,7 +535,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) { } func TestValidateUsageCleanupConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -424,6 +633,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) { } } +func TestValidateServerFrontendURL(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() frontend_url valid error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com/path" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() frontend_url with path valid error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com?utm=1" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject server.frontend_url with query") + } + + cfg.Server.FrontendURL = "https://user:pass@example.com" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject server.frontend_url with userinfo") + } + + cfg.Server.FrontendURL = "/relative" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject relative server.frontend_url") + } +} + func TestValidateFrontendRedirectURL(t *testing.T) { if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil { t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err) @@ -445,6 +688,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) { func TestWarnIfInsecureURL(t *testing.T) { warnIfInsecureURL("test", "http://example.com") warnIfInsecureURL("test", "bad://url") + warnIfInsecureURL("test", "://invalid") } func TestGenerateJWTSecretDefaultLength(t *testing.T) { @@ -458,7 +702,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) { } func TestValidateOpsCleanupScheduleRequired(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -476,7 +720,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) { } func TestValidateConcurrencyPingInterval(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -493,14 +737,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) { } func TestProvideConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) if _, err := ProvideConfig(); err != nil { t.Fatalf("ProvideConfig() error: %v", err) } } func TestValidateConfigWithLinuxDoEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -544,6 +788,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) { } } +func TestDatabaseDSNWithTimezone_WithPassword(t *testing.T) { + d := &DatabaseConfig{ + Host: "localhost", + Port: 5432, + User: "u", + Password: "p", + DBName: "db", + SSLMode: "prefer", + } + got := d.DSNWithTimezone("UTC") + if !strings.Contains(got, "password=p") { + t.Fatalf("DSNWithTimezone should include password: %q", got) + } + if !strings.Contains(got, "TimeZone=UTC") { + t.Fatalf("DSNWithTimezone should include TimeZone=UTC: %q", got) + } +} + func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) { if err := ValidateAbsoluteHTTPURL("https://"); err == nil { t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host") @@ -566,10 +828,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) { warnIfInsecureURL("secure", "https://example.com") } +func TestValidateJWTSecret_UTF8Bytes(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + // 31 bytes (< 32) even though it's 31 characters. + cfg.JWT.Secret = strings.Repeat("a", 31) + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() should reject 31-byte secret") + } + if !strings.Contains(err.Error(), "at least 32 bytes") { + t.Fatalf("Validate() error = %v", err) + } + + // 32 bytes OK. + cfg.JWT.Secret = strings.Repeat("a", 32) + err = cfg.Validate() + if err != nil { + t.Fatalf("Validate() should accept 32-byte secret: %v", err) + } +} + func TestValidateConfigErrors(t *testing.T) { buildValid := func(t *testing.T) *Config { t.Helper() - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { t.Fatalf("Load() error: %v", err) @@ -582,6 +869,26 @@ func TestValidateConfigErrors(t *testing.T) { mutate func(*Config) wantErr string }{ + { + name: "jwt secret required", + mutate: func(c *Config) { c.JWT.Secret = "" }, + wantErr: "jwt.secret is required", + }, + { + name: "jwt secret min bytes", + mutate: func(c *Config) { c.JWT.Secret = strings.Repeat("a", 31) }, + wantErr: "jwt.secret must be at least 32 bytes", + }, + { + name: "subscription maintenance worker_count non-negative", + mutate: func(c *Config) { c.SubscriptionMaintenance.WorkerCount = -1 }, + wantErr: "subscription_maintenance.worker_count", + }, + { + name: "subscription maintenance queue_size non-negative", + mutate: func(c *Config) { c.SubscriptionMaintenance.QueueSize = -1 }, + wantErr: "subscription_maintenance.queue_size", + }, { name: "jwt expire hour positive", mutate: func(c *Config) { c.JWT.ExpireHour = 0 }, @@ -592,6 +899,11 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.JWT.ExpireHour = 200 }, wantErr: "jwt.expire_hour must be <= 168", }, + { + name: "jwt access token expire minutes non-negative", + mutate: func(c *Config) { c.JWT.AccessTokenExpireMinutes = -1 }, + wantErr: "jwt.access_token_expire_minutes must be non-negative", + }, { name: "csp policy required", mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" }, @@ -779,6 +1091,16 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 }, wantErr: "gateway.stream_keepalive_interval", }, + { + name: "gateway openai ws oauth max conns factor", + mutate: func(c *Config) { c.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0 }, + wantErr: "gateway.openai_ws.oauth_max_conns_factor", + }, + { + name: "gateway openai ws apikey max conns factor", + mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 }, + wantErr: "gateway.openai_ws.apikey_max_conns_factor", + }, { name: "gateway stream data interval range", mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 }, @@ -799,6 +1121,84 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 }, wantErr: "gateway.max_line_size must be non-negative", }, + { + name: "gateway usage record worker count", + mutate: func(c *Config) { c.Gateway.UsageRecord.WorkerCount = 0 }, + wantErr: "gateway.usage_record.worker_count", + }, + { + name: "gateway usage record queue size", + mutate: func(c *Config) { c.Gateway.UsageRecord.QueueSize = 0 }, + wantErr: "gateway.usage_record.queue_size", + }, + { + name: "gateway usage record timeout", + mutate: func(c *Config) { c.Gateway.UsageRecord.TaskTimeoutSeconds = 0 }, + wantErr: "gateway.usage_record.task_timeout_seconds", + }, + { + name: "gateway usage record overflow policy", + mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowPolicy = "invalid" }, + wantErr: "gateway.usage_record.overflow_policy", + }, + { + name: "gateway usage record sample percent range", + mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowSamplePercent = 101 }, + wantErr: "gateway.usage_record.overflow_sample_percent", + }, + { + name: "gateway usage record sample percent required for sample policy", + mutate: func(c *Config) { + c.Gateway.UsageRecord.OverflowPolicy = UsageRecordOverflowPolicySample + c.Gateway.UsageRecord.OverflowSamplePercent = 0 + }, + wantErr: "gateway.usage_record.overflow_sample_percent must be positive", + }, + { + name: "gateway usage record auto scale max gte min", + mutate: func(c *Config) { + c.Gateway.UsageRecord.AutoScaleMinWorkers = 256 + c.Gateway.UsageRecord.AutoScaleMaxWorkers = 128 + }, + wantErr: "gateway.usage_record.auto_scale_max_workers", + }, + { + name: "gateway usage record worker in auto scale range", + mutate: func(c *Config) { + c.Gateway.UsageRecord.AutoScaleMinWorkers = 200 + c.Gateway.UsageRecord.AutoScaleMaxWorkers = 300 + c.Gateway.UsageRecord.WorkerCount = 128 + }, + wantErr: "gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers", + }, + { + name: "gateway usage record auto scale queue thresholds order", + mutate: func(c *Config) { + c.Gateway.UsageRecord.AutoScaleUpQueuePercent = 50 + c.Gateway.UsageRecord.AutoScaleDownQueuePercent = 50 + }, + wantErr: "gateway.usage_record.auto_scale_down_queue_percent must be less", + }, + { + name: "gateway usage record auto scale up step", + mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleUpStep = 0 }, + wantErr: "gateway.usage_record.auto_scale_up_step", + }, + { + name: "gateway usage record auto scale interval", + mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 }, + wantErr: "gateway.usage_record.auto_scale_check_interval_seconds", + }, + { + name: "gateway user group rate cache ttl", + mutate: func(c *Config) { c.Gateway.UserGroupRateCacheTTLSeconds = 0 }, + wantErr: "gateway.user_group_rate_cache_ttl_seconds", + }, + { + name: "gateway models list cache ttl range", + mutate: func(c *Config) { c.Gateway.ModelsListCacheTTLSeconds = 31 }, + wantErr: "gateway.models_list_cache_ttl_seconds", + }, { name: "gateway scheduling sticky waiting", mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 }, @@ -822,6 +1222,37 @@ func TestValidateConfigErrors(t *testing.T) { }, wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds", }, + { + name: "log level invalid", + mutate: func(c *Config) { c.Log.Level = "trace" }, + wantErr: "log.level", + }, + { + name: "log format invalid", + mutate: func(c *Config) { c.Log.Format = "plain" }, + wantErr: "log.format", + }, + { + name: "log output disabled", + mutate: func(c *Config) { + c.Log.Output.ToStdout = false + c.Log.Output.ToFile = false + }, + wantErr: "log.output.to_stdout and log.output.to_file cannot both be false", + }, + { + name: "log rotation size", + mutate: func(c *Config) { c.Log.Rotation.MaxSizeMB = 0 }, + wantErr: "log.rotation.max_size_mb", + }, + { + name: "log sampling enabled invalid", + mutate: func(c *Config) { + c.Log.Sampling.Enabled = true + c.Log.Sampling.Initial = 0 + }, + wantErr: "log.sampling.initial", + }, { name: "ops metrics collector ttl", mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 }, @@ -850,3 +1281,393 @@ func TestValidateConfigErrors(t *testing.T) { }) } } + +func TestValidateConfig_OpenAIWSRules(t *testing.T) { + buildValid := func(t *testing.T) *Config { + t.Helper() + resetViperWithJWTSecret(t) + cfg, err := Load() + require.NoError(t, err) + return cfg + } + + t.Run("sticky response id ttl 兼容旧键回填", func(t *testing.T) { + cfg := buildValid(t) + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0 + cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 7200 + + require.NoError(t, cfg.Validate()) + require.Equal(t, 7200, cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + }) + + cases := []struct { + name string + mutate func(*Config) + wantErr string + }{ + { + name: "max_conns_per_account 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxConnsPerAccount = 0 }, + wantErr: "gateway.openai_ws.max_conns_per_account", + }, + { + name: "min_idle_per_account 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MinIdlePerAccount = -1 }, + wantErr: "gateway.openai_ws.min_idle_per_account", + }, + { + name: "max_idle_per_account 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxIdlePerAccount = -1 }, + wantErr: "gateway.openai_ws.max_idle_per_account", + }, + { + name: "min_idle_per_account 不能大于 max_idle_per_account", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.MinIdlePerAccount = 3 + c.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + }, + wantErr: "gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account", + }, + { + name: "max_idle_per_account 不能大于 max_conns_per_account", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + c.Gateway.OpenAIWS.MinIdlePerAccount = 1 + c.Gateway.OpenAIWS.MaxIdlePerAccount = 3 + }, + wantErr: "gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account", + }, + { + name: "dial_timeout_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.DialTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.dial_timeout_seconds", + }, + { + name: "read_timeout_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.ReadTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.read_timeout_seconds", + }, + { + name: "write_timeout_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.WriteTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.write_timeout_seconds", + }, + { + name: "pool_target_utilization 必须在 (0,1]", + mutate: func(c *Config) { c.Gateway.OpenAIWS.PoolTargetUtilization = 0 }, + wantErr: "gateway.openai_ws.pool_target_utilization", + }, + { + name: "queue_limit_per_conn 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.QueueLimitPerConn = 0 }, + wantErr: "gateway.openai_ws.queue_limit_per_conn", + }, + { + name: "fallback_cooldown_seconds 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.FallbackCooldownSeconds = -1 }, + wantErr: "gateway.openai_ws.fallback_cooldown_seconds", + }, + { + name: "store_disabled_conn_mode 必须为 strict|adaptive|off", + mutate: func(c *Config) { c.Gateway.OpenAIWS.StoreDisabledConnMode = "invalid" }, + wantErr: "gateway.openai_ws.store_disabled_conn_mode", + }, + { + name: "ingress_mode_default 必须为 off|shared|dedicated", + mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" }, + wantErr: "gateway.openai_ws.ingress_mode_default", + }, + { + name: "payload_log_sample_rate 必须在 [0,1] 范围内", + mutate: func(c *Config) { c.Gateway.OpenAIWS.PayloadLogSampleRate = 1.2 }, + wantErr: "gateway.openai_ws.payload_log_sample_rate", + }, + { + name: "retry_total_budget_ms 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 }, + wantErr: "gateway.openai_ws.retry_total_budget_ms", + }, + { + name: "lb_top_k 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 }, + wantErr: "gateway.openai_ws.lb_top_k", + }, + { + name: "sticky_session_ttl_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.StickySessionTTLSeconds = 0 }, + wantErr: "gateway.openai_ws.sticky_session_ttl_seconds", + }, + { + name: "sticky_response_id_ttl_seconds 必须为正数", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0 + c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 0 + }, + wantErr: "gateway.openai_ws.sticky_response_id_ttl_seconds", + }, + { + name: "sticky_previous_response_ttl_seconds 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = -1 }, + wantErr: "gateway.openai_ws.sticky_previous_response_ttl_seconds", + }, + { + name: "scheduler_score_weights 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = -0.1 }, + wantErr: "gateway.openai_ws.scheduler_score_weights.* must be non-negative", + }, + { + name: "scheduler_score_weights 不能全为 0", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0 + }, + wantErr: "gateway.openai_ws.scheduler_score_weights must not all be zero", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + cfg := buildValid(t) + tc.mutate(cfg) + + err := cfg.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantErr) + }) + } +} + +func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Gateway.UsageRecord.AutoScaleEnabled = false + cfg.Gateway.UsageRecord.WorkerCount = 64 + + // 自动扩缩容关闭时,这些字段应被忽略,不应导致校验失败。 + cfg.Gateway.UsageRecord.AutoScaleMinWorkers = 0 + cfg.Gateway.UsageRecord.AutoScaleMaxWorkers = 0 + cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent = 0 + cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent = 100 + cfg.Gateway.UsageRecord.AutoScaleUpStep = 0 + cfg.Gateway.UsageRecord.AutoScaleDownStep = 0 + cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 + cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds = -1 + + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() should ignore auto scale fields when disabled: %v", err) + } +} + +func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) { + resetViperWithJWTSecret(t) + + cases := []struct { + name string + mutate func(*Config) + wantErr string + }{ + { + name: "log level required", + mutate: func(c *Config) { + c.Log.Level = "" + }, + wantErr: "log.level is required", + }, + { + name: "log format required", + mutate: func(c *Config) { + c.Log.Format = "" + }, + wantErr: "log.format is required", + }, + { + name: "log stacktrace required", + mutate: func(c *Config) { + c.Log.StacktraceLevel = "" + }, + wantErr: "log.stacktrace_level is required", + }, + { + name: "log max backups non-negative", + mutate: func(c *Config) { + c.Log.Rotation.MaxBackups = -1 + }, + wantErr: "log.rotation.max_backups must be non-negative", + }, + { + name: "log max age non-negative", + mutate: func(c *Config) { + c.Log.Rotation.MaxAgeDays = -1 + }, + wantErr: "log.rotation.max_age_days must be non-negative", + }, + { + name: "sampling thereafter non-negative when disabled", + mutate: func(c *Config) { + c.Log.Sampling.Enabled = false + c.Log.Sampling.Thereafter = -1 + }, + wantErr: "log.sampling.thereafter must be non-negative", + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + tt.mutate(cfg) + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr) + } + }) + } +} + +func TestSoraCurlCFFISidecarDefaults(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Sora.Client.CurlCFFISidecar.Enabled { + t.Fatalf("Sora curl_cffi sidecar should be enabled by default") + } + if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 { + t.Fatalf("Sora cloudflare challenge cooldown should be positive by default") + } + if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" { + t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default") + } + if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" { + t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default") + } + if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled { + t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default") + } + if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 { + t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default") + } +} + +func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CurlCFFISidecar.Enabled = false + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") { + t.Fatalf("Validate() error = %v, want sidecar enabled error", err) + } +} + +func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CurlCFFISidecar.BaseURL = " " + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") { + t.Fatalf("Validate() error = %v, want sidecar base_url required error", err) + } +} + +func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1 + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") { + t.Fatalf("Validate() error = %v, want sidecar session ttl error", err) + } +} + +func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1 + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") { + t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err) + } +} + +func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + if cfg.Gateway.UsageRecord.WorkerCount != 128 { + t.Fatalf("worker_count = %d, want 128", cfg.Gateway.UsageRecord.WorkerCount) + } + if cfg.Gateway.UsageRecord.QueueSize != 16384 { + t.Fatalf("queue_size = %d, want 16384", cfg.Gateway.UsageRecord.QueueSize) + } + if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 { + t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds) + } + if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample { + t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySample) + } + if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 { + t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent) + } + if !cfg.Gateway.UsageRecord.AutoScaleEnabled { + t.Fatalf("auto_scale_enabled = false, want true") + } + if cfg.Gateway.UsageRecord.AutoScaleMinWorkers != 128 { + t.Fatalf("auto_scale_min_workers = %d, want 128", cfg.Gateway.UsageRecord.AutoScaleMinWorkers) + } + if cfg.Gateway.UsageRecord.AutoScaleMaxWorkers != 512 { + t.Fatalf("auto_scale_max_workers = %d, want 512", cfg.Gateway.UsageRecord.AutoScaleMaxWorkers) + } + if cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent != 70 { + t.Fatalf("auto_scale_up_queue_percent = %d, want 70", cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent) + } + if cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent != 15 { + t.Fatalf("auto_scale_down_queue_percent = %d, want 15", cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent) + } + if cfg.Gateway.UsageRecord.AutoScaleUpStep != 32 { + t.Fatalf("auto_scale_up_step = %d, want 32", cfg.Gateway.UsageRecord.AutoScaleUpStep) + } + if cfg.Gateway.UsageRecord.AutoScaleDownStep != 16 { + t.Fatalf("auto_scale_down_step = %d, want 16", cfg.Gateway.UsageRecord.AutoScaleDownStep) + } + if cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds != 3 { + t.Fatalf("auto_scale_check_interval_seconds = %d, want 3", cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds) + } + if cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds != 10 { + t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds) + } +} diff --git a/backend/internal/config/wire.go b/backend/internal/config/wire.go index ec26c401..bf6b3bd6 100644 --- a/backend/internal/config/wire.go +++ b/backend/internal/config/wire.go @@ -9,5 +9,5 @@ var ProviderSet = wire.NewSet( // ProvideConfig 提供应用配置 func ProvideConfig() (*Config, error) { - return Load() + return LoadForBootstrap() } diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 05b5adc1..d7bb50fc 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -22,6 +22,7 @@ const ( PlatformOpenAI = "openai" PlatformGemini = "gemini" PlatformAntigravity = "antigravity" + PlatformSora = "sora" ) // Account type constants @@ -73,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{ "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型 "claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射 "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型 + "claude-sonnet-4-6": "claude-sonnet-4-6", "claude-sonnet-4-5": "claude-sonnet-4-5", "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", // Claude 详细版本 ID 映射 @@ -87,14 +89,24 @@ var DefaultAntigravityModelMapping = map[string]string{ "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", "gemini-2.5-pro": "gemini-2.5-pro", // Gemini 3 白名单 - "gemini-3-flash": "gemini-3-flash", - "gemini-3-pro-high": "gemini-3-pro-high", - "gemini-3-pro-low": "gemini-3-pro-low", - "gemini-3-pro-image": "gemini-3-pro-image", + "gemini-3-flash": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", // Gemini 3 preview 映射 - "gemini-3-flash-preview": "gemini-3-flash", - "gemini-3-pro-preview": "gemini-3-pro-high", - "gemini-3-pro-image-preview": "gemini-3-pro-image", + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-preview": "gemini-3-pro-high", + // Gemini 3.1 白名单 + "gemini-3.1-pro-high": "gemini-3.1-pro-high", + "gemini-3.1-pro-low": "gemini-3.1-pro-low", + // Gemini 3.1 preview 映射 + "gemini-3.1-pro-preview": "gemini-3.1-pro-high", + // Gemini 3.1 image 白名单 + "gemini-3.1-flash-image": "gemini-3.1-flash-image", + // Gemini 3.1 image preview 映射 + "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", + // Gemini 3 image 兼容映射(向 3.1 image 迁移) + "gemini-3-pro-image": "gemini-3.1-flash-image", + "gemini-3-pro-image-preview": "gemini-3.1-flash-image", // 其他官方模型 "gpt-oss-120b-medium": "gpt-oss-120b-medium", "tab_flash_lite_preview": "tab_flash_lite_preview", diff --git a/backend/internal/domain/constants_test.go b/backend/internal/domain/constants_test.go new file mode 100644 index 00000000..29605ac6 --- /dev/null +++ b/backend/internal/domain/constants_test.go @@ -0,0 +1,24 @@ +package domain + +import "testing" + +func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) { + t.Parallel() + + cases := map[string]string{ + "gemini-3.1-flash-image": "gemini-3.1-flash-image", + "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", + "gemini-3-pro-image": "gemini-3.1-flash-image", + "gemini-3-pro-image-preview": "gemini-3.1-flash-image", + } + + for from, want := range cases { + got, ok := DefaultAntigravityModelMapping[from] + if !ok { + t.Fatalf("expected mapping for %q to exist", from) + } + if got != want { + t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want) + } + } +} diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index b5d1dd0a..4ce17219 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) { return } - dataPayload := req.Data - if err := validateDataHeader(dataPayload); err != nil { + if err := validateDataHeader(req.Data); err != nil { response.BadRequest(c, err.Error()) return } + executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + return h.importData(ctx, req) + }) +} + +func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) { skipDefaultGroupBind := true if req.SkipDefaultGroupBind != nil { skipDefaultGroupBind = *req.SkipDefaultGroupBind } + dataPayload := req.Data result := DataImportResult{} - existingProxies, err := h.listAllProxies(c.Request.Context()) + + existingProxies, err := h.listAllProxies(ctx) if err != nil { - response.ErrorFrom(c, err) - return + return result, err } proxyKeyToID := make(map[string]int64, len(existingProxies)) @@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) { proxyKeyToID[key] = existingID result.ProxyReused++ if normalizedStatus != "" { - if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus { - _, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{ + if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus { + _, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{ Status: normalizedStatus, }) } @@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { continue } - created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ + created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{ Name: defaultProxyName(item.Name), Protocol: item.Protocol, Host: item.Host, @@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) { Username: item.Username, Password: item.Password, }) - if err != nil { + if createErr != nil { result.ProxyFailed++ result.Errors = append(result.Errors, DataImportError{ Kind: "proxy", Name: item.Name, ProxyKey: key, - Message: err.Error(), + Message: createErr.Error(), }) continue } @@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { result.ProxyCreated++ if normalizedStatus != "" && normalizedStatus != created.Status { - _, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{ + _, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{ Status: normalizedStatus, }) } @@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { SkipDefaultGroupBind: skipDefaultGroupBind, } - if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil { + if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil { result.AccountFailed++ result.Errors = append(result.Errors, DataImportError{ Kind: "account", @@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { result.AccountCreated++ } - response.Success(c, result) + return result, nil } func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) { @@ -341,7 +347,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc pageSize := dataPageCap var out []service.Account for { - items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search) + items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0) if err != nil { return nil, err } diff --git a/backend/internal/handler/admin/account_data_handler_test.go b/backend/internal/handler/admin/account_data_handler_test.go index c8b04c2a..285033a1 100644 --- a/backend/internal/handler/admin/account_data_handler_test.go +++ b/backend/internal/handler/admin/account_data_handler_test.go @@ -64,6 +64,7 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) { nil, nil, nil, + nil, ) router.GET("/api/v1/admin/accounts/data", h.ExportData) diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 9a13b57c..98ead284 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -2,7 +2,13 @@ package admin import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" "errors" + "fmt" + "net/http" "strconv" "strings" "sync" @@ -10,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" @@ -46,6 +53,7 @@ type AccountHandler struct { concurrencyService *service.ConcurrencyService crsSyncService *service.CRSSyncService sessionLimitCache service.SessionLimitCache + rpmCache service.RPMCache tokenCacheInvalidator service.TokenCacheInvalidator } @@ -62,6 +70,7 @@ func NewAccountHandler( concurrencyService *service.ConcurrencyService, crsSyncService *service.CRSSyncService, sessionLimitCache service.SessionLimitCache, + rpmCache service.RPMCache, tokenCacheInvalidator service.TokenCacheInvalidator, ) *AccountHandler { return &AccountHandler{ @@ -76,6 +85,7 @@ func NewAccountHandler( concurrencyService: concurrencyService, crsSyncService: crsSyncService, sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, tokenCacheInvalidator: tokenCacheInvalidator, } } @@ -133,6 +143,13 @@ type BulkUpdateAccountsRequest struct { ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 } +// CheckMixedChannelRequest represents check mixed channel risk request +type CheckMixedChannelRequest struct { + Platform string `json:"platform" binding:"required"` + GroupIDs []int64 `json:"group_ids"` + AccountID *int64 `json:"account_id"` +} + // AccountWithConcurrency extends Account with real-time concurrency info type AccountWithConcurrency struct { *dto.Account @@ -140,6 +157,51 @@ type AccountWithConcurrency struct { // 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回 CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用 ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数 + CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数 +} + +func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency { + item := AccountWithConcurrency{ + Account: dto.AccountFromService(account), + CurrentConcurrency: 0, + } + if account == nil { + return item + } + + if h.concurrencyService != nil { + if counts, err := h.concurrencyService.GetAccountConcurrencyBatch(ctx, []int64{account.ID}); err == nil { + item.CurrentConcurrency = counts[account.ID] + } + } + + if account.IsAnthropicOAuthOrSetupToken() { + if h.accountUsageService != nil && account.GetWindowCostLimit() > 0 { + startTime := account.GetCurrentWindowStartTime() + if stats, err := h.accountUsageService.GetAccountWindowStats(ctx, account.ID, startTime); err == nil && stats != nil { + cost := stats.StandardCost + item.CurrentWindowCost = &cost + } + } + + if h.sessionLimitCache != nil && account.GetMaxSessions() > 0 { + idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute + idleTimeouts := map[int64]time.Duration{account.ID: idleTimeout} + if sessions, err := h.sessionLimitCache.GetActiveSessionCountBatch(ctx, []int64{account.ID}, idleTimeouts); err == nil { + if count, ok := sessions[account.ID]; ok { + item.ActiveSessions = &count + } + } + } + + if h.rpmCache != nil && account.GetBaseRPM() > 0 { + if rpm, err := h.rpmCache.GetRPM(ctx, account.ID); err == nil { + item.CurrentRPM = &rpm + } + } + } + + return item } // List handles listing all accounts with pagination @@ -156,7 +218,12 @@ func (h *AccountHandler) List(c *gin.Context) { search = search[:100] } - accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) + var groupID int64 + if groupIDStr := c.Query("group"); groupIDStr != "" { + groupID, _ = strconv.ParseInt(groupIDStr, 10, 64) + } + + accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID) if err != nil { response.ErrorFrom(c, err) return @@ -174,9 +241,10 @@ func (h *AccountHandler) List(c *gin.Context) { concurrencyCounts = make(map[int64]int) } - // 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能) + // 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能) windowCostAccountIDs := make([]int64, 0) sessionLimitAccountIDs := make([]int64, 0) + rpmAccountIDs := make([]int64, 0) sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置 for i := range accounts { acc := &accounts[i] @@ -188,12 +256,24 @@ func (h *AccountHandler) List(c *gin.Context) { sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID) sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute } + if acc.GetBaseRPM() > 0 { + rpmAccountIDs = append(rpmAccountIDs, acc.ID) + } } } - // 并行获取窗口费用和活跃会话数 + // 并行获取窗口费用、活跃会话数和 RPM 计数 var windowCosts map[int64]float64 var activeSessions map[int64]int + var rpmCounts map[int64]int + + // 获取 RPM 计数(批量查询) + if len(rpmAccountIDs) > 0 && h.rpmCache != nil { + rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs) + if rpmCounts == nil { + rpmCounts = make(map[int64]int) + } + } // 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置) if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil { @@ -254,12 +334,81 @@ func (h *AccountHandler) List(c *gin.Context) { } } + // 添加 RPM 计数(仅当启用时) + if rpmCounts != nil { + if rpm, ok := rpmCounts[acc.ID]; ok { + item.CurrentRPM = &rpm + } + } + result[i] = item } + etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search) + if etag != "" { + c.Header("ETag", etag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), etag) { + c.Status(http.StatusNotModified) + return + } + } + response.Paginated(c, result, total, page, pageSize) } +func buildAccountsListETag( + items []AccountWithConcurrency, + total int64, + page, pageSize int, + platform, accountType, status, search string, +) string { + payload := struct { + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + Platform string `json:"platform"` + AccountType string `json:"type"` + Status string `json:"status"` + Search string `json:"search"` + Items []AccountWithConcurrency `json:"items"` + }{ + Total: total, + Page: page, + PageSize: pageSize, + Platform: platform, + AccountType: accountType, + Status: status, + Search: search, + Items: items, + } + raw, err := json.Marshal(payload) + if err != nil { + return "" + } + sum := sha256.Sum256(raw) + return "\"" + hex.EncodeToString(sum[:]) + "\"" +} + +func ifNoneMatchMatched(ifNoneMatch, etag string) bool { + if etag == "" || ifNoneMatch == "" { + return false + } + for _, token := range strings.Split(ifNoneMatch, ",") { + candidate := strings.TrimSpace(token) + if candidate == "*" { + return true + } + if candidate == etag { + return true + } + if strings.HasPrefix(candidate, "W/") && strings.TrimPrefix(candidate, "W/") == etag { + return true + } + } + return false +} + // GetByID handles getting an account by ID // GET /api/v1/admin/accounts/:id func (h *AccountHandler) GetByID(c *gin.Context) { @@ -275,7 +424,51 @@ func (h *AccountHandler) GetByID(c *gin.Context) { return } - response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) +} + +// CheckMixedChannel handles checking mixed channel risk for account-group binding. +// POST /api/v1/admin/accounts/check-mixed-channel +func (h *AccountHandler) CheckMixedChannel(c *gin.Context) { + var req CheckMixedChannelRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if len(req.GroupIDs) == 0 { + response.Success(c, gin.H{"has_risk": false}) + return + } + + accountID := int64(0) + if req.AccountID != nil { + accountID = *req.AccountID + } + + err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs) + if err != nil { + var mixedErr *service.MixedChannelError + if errors.As(err, &mixedErr) { + response.Success(c, gin.H{ + "has_risk": true, + "error": "mixed_channel_warning", + "message": mixedErr.Error(), + "details": gin.H{ + "group_id": mixedErr.GroupID, + "group_name": mixedErr.GroupName, + "current_platform": mixedErr.CurrentPlatform, + "other_platform": mixedErr.OtherPlatform, + }, + }) + return + } + + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"has_risk": false}) } // Create handles creating a new account @@ -290,50 +483,57 @@ func (h *AccountHandler) Create(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk - account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ - Name: req.Name, - Notes: req.Notes, - Platform: req.Platform, - Type: req.Type, - Credentials: req.Credentials, - Extra: req.Extra, - ProxyID: req.ProxyID, - Concurrency: req.Concurrency, - Priority: req.Priority, - RateMultiplier: req.RateMultiplier, - GroupIDs: req.GroupIDs, - ExpiresAt: req.ExpiresAt, - AutoPauseOnExpired: req.AutoPauseOnExpired, - SkipMixedChannelCheck: skipCheck, + result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ + Name: req.Name, + Notes: req.Notes, + Platform: req.Platform, + Type: req.Type, + Credentials: req.Credentials, + Extra: req.Extra, + ProxyID: req.ProxyID, + Concurrency: req.Concurrency, + Priority: req.Priority, + RateMultiplier: req.RateMultiplier, + GroupIDs: req.GroupIDs, + ExpiresAt: req.ExpiresAt, + AutoPauseOnExpired: req.AutoPauseOnExpired, + SkipMixedChannelCheck: skipCheck, + }) + if execErr != nil { + return nil, execErr + } + return h.buildAccountResponseWithRuntime(ctx, account), nil }) if err != nil { // 检查是否为混合渠道错误 var mixedErr *service.MixedChannelError if errors.As(err, &mixedErr) { - // 返回特殊错误码要求确认 + // 创建接口仅返回最小必要字段,详细信息由专门检查接口提供 c.JSON(409, gin.H{ "error": "mixed_channel_warning", "message": mixedErr.Error(), - "details": gin.H{ - "group_id": mixedErr.GroupID, - "group_name": mixedErr.GroupName, - "current_platform": mixedErr.CurrentPlatform, - "other_platform": mixedErr.OtherPlatform, - }, - "require_confirmation": true, }) return } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } response.ErrorFrom(c, err) return } - response.Success(c, dto.AccountFromService(account)) + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) } // Update handles updating an account @@ -354,6 +554,8 @@ func (h *AccountHandler) Update(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -378,17 +580,10 @@ func (h *AccountHandler) Update(c *gin.Context) { // 检查是否为混合渠道错误 var mixedErr *service.MixedChannelError if errors.As(err, &mixedErr) { - // 返回特殊错误码要求确认 + // 更新接口仅返回最小必要字段,详细信息由专门检查接口提供 c.JSON(409, gin.H{ "error": "mixed_channel_warning", "message": mixedErr.Error(), - "details": gin.H{ - "group_id": mixedErr.GroupID, - "group_name": mixedErr.GroupName, - "current_platform": mixedErr.CurrentPlatform, - "other_platform": mixedErr.OtherPlatform, - }, - "require_confirmation": true, }) return } @@ -397,7 +592,7 @@ func (h *AccountHandler) Update(c *gin.Context) { return } - response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // Delete handles deleting an account @@ -424,10 +619,17 @@ type TestAccountRequest struct { } type SyncFromCRSRequest struct { - BaseURL string `json:"base_url" binding:"required"` - Username string `json:"username" binding:"required"` - Password string `json:"password" binding:"required"` - SyncProxies *bool `json:"sync_proxies"` + BaseURL string `json:"base_url" binding:"required"` + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` + SyncProxies *bool `json:"sync_proxies"` + SelectedAccountIDs []string `json:"selected_account_ids"` +} + +type PreviewFromCRSRequest struct { + BaseURL string `json:"base_url" binding:"required"` + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` } // Test handles testing account connectivity with SSE streaming @@ -466,10 +668,11 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) { } result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{ - BaseURL: req.BaseURL, - Username: req.Username, - Password: req.Password, - SyncProxies: syncProxies, + BaseURL: req.BaseURL, + Username: req.Username, + Password: req.Password, + SyncProxies: syncProxies, + SelectedAccountIDs: req.SelectedAccountIDs, }) if err != nil { // Provide detailed error message for CRS sync failures @@ -480,6 +683,28 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) { response.Success(c, result) } +// PreviewFromCRS handles previewing accounts from CRS before sync +// POST /api/v1/admin/accounts/sync/crs/preview +func (h *AccountHandler) PreviewFromCRS(c *gin.Context) { + var req PreviewFromCRSRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + result, err := h.crsSyncService.PreviewFromCRS(c.Request.Context(), service.SyncFromCRSInput{ + BaseURL: req.BaseURL, + Username: req.Username, + Password: req.Password, + }) + if err != nil { + response.InternalError(c, "CRS preview failed: "+err.Error()) + return + } + + response.Success(c, result) +} + // Refresh handles refreshing account credentials // POST /api/v1/admin/accounts/:id/refresh func (h *AccountHandler) Refresh(c *gin.Context) { @@ -625,7 +850,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) { } } - response.Success(c, dto.AccountFromService(updatedAccount)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount)) } // GetStats handles getting account statistics @@ -683,7 +908,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) { } } - response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // BatchCreate handles batch creating accounts @@ -697,61 +922,65 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { return } - ctx := c.Request.Context() - success := 0 - failed := 0 - results := make([]gin.H, 0, len(req.Accounts)) + executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + success := 0 + failed := 0 + results := make([]gin.H, 0, len(req.Accounts)) - for _, item := range req.Accounts { - if item.RateMultiplier != nil && *item.RateMultiplier < 0 { - failed++ + for _, item := range req.Accounts { + if item.RateMultiplier != nil && *item.RateMultiplier < 0 { + failed++ + results = append(results, gin.H{ + "name": item.Name, + "success": false, + "error": "rate_multiplier must be >= 0", + }) + continue + } + + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(item.Extra) + + skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk + + account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ + Name: item.Name, + Notes: item.Notes, + Platform: item.Platform, + Type: item.Type, + Credentials: item.Credentials, + Extra: item.Extra, + ProxyID: item.ProxyID, + Concurrency: item.Concurrency, + Priority: item.Priority, + RateMultiplier: item.RateMultiplier, + GroupIDs: item.GroupIDs, + ExpiresAt: item.ExpiresAt, + AutoPauseOnExpired: item.AutoPauseOnExpired, + SkipMixedChannelCheck: skipCheck, + }) + if err != nil { + failed++ + results = append(results, gin.H{ + "name": item.Name, + "success": false, + "error": err.Error(), + }) + continue + } + success++ results = append(results, gin.H{ "name": item.Name, - "success": false, - "error": "rate_multiplier must be >= 0", + "id": account.ID, + "success": true, }) - continue } - skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk - - account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ - Name: item.Name, - Notes: item.Notes, - Platform: item.Platform, - Type: item.Type, - Credentials: item.Credentials, - Extra: item.Extra, - ProxyID: item.ProxyID, - Concurrency: item.Concurrency, - Priority: item.Priority, - RateMultiplier: item.RateMultiplier, - GroupIDs: item.GroupIDs, - ExpiresAt: item.ExpiresAt, - AutoPauseOnExpired: item.AutoPauseOnExpired, - SkipMixedChannelCheck: skipCheck, - }) - if err != nil { - failed++ - results = append(results, gin.H{ - "name": item.Name, - "success": false, - "error": err.Error(), - }) - continue - } - success++ - results = append(results, gin.H{ - "name": item.Name, - "id": account.ID, - "success": true, - }) - } - - response.Success(c, gin.H{ - "success": success, - "failed": failed, - "results": results, + return gin.H{ + "success": success, + "failed": failed, + "results": results, + }, nil }) } @@ -789,57 +1018,58 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) { } ctx := c.Request.Context() - success := 0 - failed := 0 - results := []gin.H{} + // 阶段一:预验证所有账号存在,收集 credentials + type accountUpdate struct { + ID int64 + Credentials map[string]any + } + updates := make([]accountUpdate, 0, len(req.AccountIDs)) for _, accountID := range req.AccountIDs { - // Get account account, err := h.adminService.GetAccount(ctx, accountID) if err != nil { - failed++ - results = append(results, gin.H{ - "account_id": accountID, - "success": false, - "error": "Account not found", - }) - continue + response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID)) + return } - - // Update credentials field if account.Credentials == nil { account.Credentials = make(map[string]any) } - account.Credentials[req.Field] = req.Value + updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials}) + } - // Update account - updateInput := &service.UpdateAccountInput{ - Credentials: account.Credentials, - } - - _, err = h.adminService.UpdateAccount(ctx, accountID, updateInput) - if err != nil { + // 阶段二:依次更新,返回每个账号的成功/失败明细,便于调用方重试 + success := 0 + failed := 0 + successIDs := make([]int64, 0, len(updates)) + failedIDs := make([]int64, 0, len(updates)) + results := make([]gin.H, 0, len(updates)) + for _, u := range updates { + updateInput := &service.UpdateAccountInput{Credentials: u.Credentials} + if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil { failed++ + failedIDs = append(failedIDs, u.ID) results = append(results, gin.H{ - "account_id": accountID, + "account_id": u.ID, "success": false, "error": err.Error(), }) continue } - success++ + successIDs = append(successIDs, u.ID) results = append(results, gin.H{ - "account_id": accountID, + "account_id": u.ID, "success": true, }) } response.Success(c, gin.H{ - "success": success, - "failed": failed, - "results": results, + "success": success, + "failed": failed, + "success_ids": successIDs, + "failed_ids": failedIDs, + "results": results, }) } @@ -855,6 +1085,8 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -890,6 +1122,14 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { SkipMixedChannelCheck: skipCheck, }) if err != nil { + var mixedErr *service.MixedChannelError + if errors.As(err, &mixedErr) { + c.JSON(409, gin.H{ + "error": "mixed_channel_warning", + "message": mixedErr.Error(), + }) + return + } response.ErrorFrom(c, err) return } @@ -1074,7 +1314,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) { return } - response.Success(c, gin.H{"message": "Rate limit cleared successfully"}) + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // GetTempUnschedulable handles getting temporary unschedulable status @@ -1138,6 +1384,34 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) { response.Success(c, stats) } +// BatchTodayStatsRequest 批量今日统计请求体。 +type BatchTodayStatsRequest struct { + AccountIDs []int64 `json:"account_ids" binding:"required"` +} + +// GetBatchTodayStats 批量获取多个账号的今日统计。 +// POST /api/v1/admin/accounts/today-stats/batch +func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) { + var req BatchTodayStatsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if len(req.AccountIDs) == 0 { + response.Success(c, gin.H{"stats": map[string]any{}}) + return + } + + stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"stats": stats}) +} + // SetSchedulableRequest represents the request body for setting schedulable status type SetSchedulableRequest struct { Schedulable bool `json:"schedulable"` @@ -1164,7 +1438,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) { return } - response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // GetAvailableModels handles getting available models for an account @@ -1261,32 +1535,14 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { // Handle Antigravity accounts: return Claude + Gemini models if account.Platform == service.PlatformAntigravity { - // Antigravity 支持 Claude 和部分 Gemini 模型 - type UnifiedModel struct { - ID string `json:"id"` - Type string `json:"type"` - DisplayName string `json:"display_name"` - } + // 直接复用 antigravity.DefaultModels(),与 /v1/models 端点保持同步 + response.Success(c, antigravity.DefaultModels()) + return + } - var models []UnifiedModel - - // 添加 Claude 模型 - for _, m := range claude.DefaultModels { - models = append(models, UnifiedModel{ - ID: m.ID, - Type: m.Type, - DisplayName: m.DisplayName, - }) - } - - // 添加 Gemini 3 系列模型用于测试 - geminiTestModels := []UnifiedModel{ - {ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash"}, - {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview"}, - } - models = append(models, geminiTestModels...) - - response.Success(c, models) + // Handle Sora accounts + if account.Platform == service.PlatformSora { + response.Success(c, service.DefaultSoraModels(nil)) return } @@ -1399,7 +1655,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { accounts := make([]*service.Account, 0) if len(req.AccountIDs) == 0 { - allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "") + allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0) if err != nil { response.ErrorFrom(c, err) return @@ -1497,3 +1753,22 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) { response.Success(c, domain.DefaultAntigravityModelMapping) } + +// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。 +// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。 +func sanitizeExtraBaseRPM(extra map[string]any) { + if extra == nil { + return + } + raw, ok := extra["base_rpm"] + if !ok { + return + } + v := service.ParseExtraInt(raw) + if v < 0 { + v = 0 + } else if v > 10000 { + v = 10000 + } + extra["base_rpm"] = v +} diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go new file mode 100644 index 00000000..24ec5bcf --- /dev/null +++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go @@ -0,0 +1,198 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel) + router.POST("/api/v1/admin/accounts", accountHandler.Create) + router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update) + router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate) + return router +} + +func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) { + adminSvc := newStubAdminService() + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "platform": "antigravity", + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, float64(0), resp["code"]) + data, ok := resp["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, false, data["has_risk"]) + require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID) + require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform) + require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs) +} + +func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.checkMixedErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "platform": "antigravity", + "group_ids": []int64{27}, + "account_id": 99, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, float64(0), resp["code"]) + data, ok := resp["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, data["has_risk"]) + require.Equal(t, "mixed_channel_warning", data["error"]) + details, ok := data["details"].(map[string]any) + require.True(t, ok) + require.Equal(t, float64(27), details["group_id"]) + require.Equal(t, "claude-max", details["group_name"]) + require.Equal(t, "Antigravity", details["current_platform"]) + require.Equal(t, "Anthropic", details["other_platform"]) + require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID) +} + +func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.createAccountErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "name": "ag-oauth-1", + "platform": "antigravity", + "type": "oauth", + "credentials": map[string]any{"refresh_token": "rt"}, + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "mixed_channel_warning", resp["error"]) + require.Contains(t, resp["message"], "mixed_channel_warning") + _, hasDetails := resp["details"] + _, hasRequireConfirmation := resp["require_confirmation"] + require.False(t, hasDetails) + require.False(t, hasRequireConfirmation) +} + +func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.updateAccountErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "mixed_channel_warning", resp["error"]) + require.Contains(t, resp["message"], "mixed_channel_warning") + _, hasDetails := resp["details"] + _, hasRequireConfirmation := resp["require_confirmation"] + require.False(t, hasDetails) + require.False(t, hasRequireConfirmation) +} + +func TestAccountHandlerBulkUpdateMixedChannelConflict(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.bulkUpdateAccountErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1, 2, 3}, + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "mixed_channel_warning", resp["error"]) + require.Contains(t, resp["message"], "claude-max") +} + +func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) { + adminSvc := newStubAdminService() + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1, 2}, + "group_ids": []int64{27}, + "confirm_mixed_channel_risk": true, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, float64(0), resp["code"]) + data, ok := resp["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, float64(2), data["success"]) + require.Equal(t, float64(0), data["failed"]) +} diff --git a/backend/internal/handler/admin/account_handler_passthrough_test.go b/backend/internal/handler/admin/account_handler_passthrough_test.go new file mode 100644 index 00000000..d86501c0 --- /dev/null +++ b/backend/internal/handler/admin/account_handler_passthrough_test.go @@ -0,0 +1,67 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testing.T) { + gin.SetMode(gin.TestMode) + + adminSvc := newStubAdminService() + handler := NewAccountHandler( + adminSvc, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + router := gin.New() + router.POST("/api/v1/admin/accounts", handler.Create) + + body := map[string]any{ + "name": "anthropic-key-1", + "platform": "anthropic", + "type": "apikey", + "credentials": map[string]any{ + "api_key": "sk-ant-xxx", + "base_url": "https://api.anthropic.com", + }, + "extra": map[string]any{ + "anthropic_passthrough": true, + }, + "concurrency": 1, + "priority": 1, + } + raw, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(raw)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Len(t, adminSvc.createdAccounts, 1) + + created := adminSvc.createdAccounts[0] + require.Equal(t, "anthropic", created.Platform) + require.Equal(t, "apikey", created.Type) + require.NotNil(t, created.Extra) + require.Equal(t, true, created.Extra["anthropic_passthrough"]) +} diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index e0f731e1..4de10d3e 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -16,10 +16,10 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router := gin.New() adminSvc := newStubAdminService() - userHandler := NewUserHandler(adminSvc) + userHandler := NewUserHandler(adminSvc, nil) groupHandler := NewGroupHandler(adminSvc) proxyHandler := NewProxyHandler(adminSvc) - redeemHandler := NewRedeemHandler(adminSvc) + redeemHandler := NewRedeemHandler(adminSvc, nil) router.GET("/api/v1/admin/users", userHandler.List) router.GET("/api/v1/admin/users/:id", userHandler.GetByID) @@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete) router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete) router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test) + router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality) router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats) router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts) @@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil) router.ServeHTTP(rec, req) diff --git a/backend/internal/handler/admin/admin_helpers_test.go b/backend/internal/handler/admin/admin_helpers_test.go index 863c755c..3833d32e 100644 --- a/backend/internal/handler/admin/admin_helpers_test.go +++ b/backend/internal/handler/admin/admin_helpers_test.go @@ -58,6 +58,96 @@ func TestParseOpsDuration(t *testing.T) { require.False(t, ok) } +func TestParseOpsOpenAITokenStatsDuration(t *testing.T) { + tests := []struct { + input string + want time.Duration + ok bool + }{ + {input: "30m", want: 30 * time.Minute, ok: true}, + {input: "1h", want: time.Hour, ok: true}, + {input: "1d", want: 24 * time.Hour, ok: true}, + {input: "15d", want: 15 * 24 * time.Hour, ok: true}, + {input: "30d", want: 30 * 24 * time.Hour, ok: true}, + {input: "7d", want: 0, ok: false}, + } + + for _, tt := range tests { + got, ok := parseOpsOpenAITokenStatsDuration(tt.input) + require.Equal(t, tt.ok, ok, "input=%s", tt.input) + require.Equal(t, tt.want, got, "input=%s", tt.input) + } +} + +func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + before := time.Now().UTC() + filter, err := parseOpsOpenAITokenStatsFilter(c) + after := time.Now().UTC() + + require.NoError(t, err) + require.NotNil(t, filter) + require.Equal(t, "30d", filter.TimeRange) + require.Equal(t, 1, filter.Page) + require.Equal(t, 20, filter.PageSize) + require.Equal(t, 0, filter.TopN) + require.Nil(t, filter.GroupID) + require.Equal(t, "", filter.Platform) + require.True(t, filter.StartTime.Before(filter.EndTime)) + require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second) + require.WithinDuration(t, after, filter.EndTime, 2*time.Second) +} + +func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest( + http.MethodGet, + "/?time_range=1h&platform=openai&group_id=12&top_n=50", + nil, + ) + + filter, err := parseOpsOpenAITokenStatsFilter(c) + require.NoError(t, err) + require.Equal(t, "1h", filter.TimeRange) + require.Equal(t, "openai", filter.Platform) + require.NotNil(t, filter.GroupID) + require.Equal(t, int64(12), *filter.GroupID) + require.Equal(t, 50, filter.TopN) + require.Equal(t, 0, filter.Page) + require.Equal(t, 0, filter.PageSize) +} + +func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) { + tests := []string{ + "/?time_range=7d", + "/?group_id=0", + "/?group_id=abc", + "/?top_n=0", + "/?top_n=101", + "/?top_n=10&page=1", + "/?top_n=10&page_size=20", + "/?page=0", + "/?page_size=0", + "/?page_size=101", + } + + gin.SetMode(gin.TestMode) + for _, rawURL := range tests { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil) + + _, err := parseOpsOpenAITokenStatsFilter(c) + require.Error(t, err, "url=%s", rawURL) + } +} + func TestParseOpsTimeRange(t *testing.T) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 77d288f9..f3b99ddb 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -10,19 +10,28 @@ import ( ) type stubAdminService struct { - users []service.User - apiKeys []service.APIKey - groups []service.Group - accounts []service.Account - proxies []service.Proxy - proxyCounts []service.ProxyWithAccountCount - redeems []service.RedeemCode - createdAccounts []*service.CreateAccountInput - createdProxies []*service.CreateProxyInput - updatedProxyIDs []int64 - updatedProxies []*service.UpdateProxyInput - testedProxyIDs []int64 - mu sync.Mutex + users []service.User + apiKeys []service.APIKey + groups []service.Group + accounts []service.Account + proxies []service.Proxy + proxyCounts []service.ProxyWithAccountCount + redeems []service.RedeemCode + createdAccounts []*service.CreateAccountInput + createdProxies []*service.CreateProxyInput + updatedProxyIDs []int64 + updatedProxies []*service.UpdateProxyInput + testedProxyIDs []int64 + createAccountErr error + updateAccountErr error + bulkUpdateAccountErr error + checkMixedErr error + lastMixedCheck struct { + accountID int64 + platform string + groupIDs []int64 + } + mu sync.Mutex } func newStubAdminService() *stubAdminService { @@ -166,7 +175,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p return s.apiKeys, int64(len(s.apiKeys)), nil } -func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) { +func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) { return s.accounts, int64(len(s.accounts)), nil } @@ -188,11 +197,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre s.mu.Lock() s.createdAccounts = append(s.createdAccounts, input) s.mu.Unlock() + if s.createAccountErr != nil { + return nil, s.createAccountErr + } account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive} return &account, nil } func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { + if s.updateAccountErr != nil { + return nil, s.updateAccountErr + } account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive} return &account, nil } @@ -221,7 +236,17 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, } func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) { - return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil + if s.bulkUpdateAccountErr != nil { + return nil, s.bulkUpdateAccountErr + } + return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil +} + +func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { + s.lastMixedCheck.accountID = currentAccountID + s.lastMixedCheck.platform = currentAccountPlatform + s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...) + return s.checkMixedErr } func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) { @@ -327,6 +352,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr return &service.ProxyTestResult{Success: true, Message: "ok"}, nil } +func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) { + return &service.ProxyQualityCheckResult{ + ProxyID: id, + Score: 95, + Grade: "A", + Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项", + PassedCount: 5, + WarnCount: 0, + FailedCount: 0, + ChallengeCount: 0, + CheckedAt: time.Now().Unix(), + Items: []service.ProxyQualityCheckItem{ + {Target: "base_connectivity", Status: "pass", Message: "ok"}, + {Target: "openai", Status: "pass", HTTPStatus: 401}, + {Target: "anthropic", Status: "pass", HTTPStatus: 401}, + {Target: "gemini", Status: "pass", HTTPStatus: 200}, + {Target: "sora", Status: "pass", HTTPStatus: 401}, + }, + }, nil +} + func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) { return s.redeems, int64(len(s.redeems)), nil } @@ -357,5 +403,27 @@ func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int return s.redeems, int64(len(s.redeems)), 100.0, nil } +func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error { + return nil +} + +func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) { + for i := range s.apiKeys { + if s.apiKeys[i].ID == keyID { + k := s.apiKeys[i] + if groupID != nil { + if *groupID == 0 { + k.GroupID = nil + } else { + gid := *groupID + k.GroupID = &gid + } + } + return &service.AdminUpdateAPIKeyGroupIDResult{APIKey: &k}, nil + } + } + return nil, service.ErrAPIKeyNotFound +} + // Ensure stub implements interface. var _ service.AdminService = (*stubAdminService)(nil) diff --git a/backend/internal/handler/admin/antigravity_oauth_handler.go b/backend/internal/handler/admin/antigravity_oauth_handler.go index 18541684..7488965d 100644 --- a/backend/internal/handler/admin/antigravity_oauth_handler.go +++ b/backend/internal/handler/admin/antigravity_oauth_handler.go @@ -65,3 +65,27 @@ func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) { response.Success(c, tokenInfo) } + +// AntigravityRefreshTokenRequest represents the request for validating Antigravity refresh token +type AntigravityRefreshTokenRequest struct { + RefreshToken string `json:"refresh_token" binding:"required"` + ProxyID *int64 `json:"proxy_id"` +} + +// RefreshToken validates an Antigravity refresh token and returns full token info +// POST /api/v1/admin/antigravity/oauth/refresh-token +func (h *AntigravityOAuthHandler) RefreshToken(c *gin.Context) { + var req AntigravityRefreshTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "请求无效: "+err.Error()) + return + } + + tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, tokenInfo) +} diff --git a/backend/internal/handler/admin/apikey_handler.go b/backend/internal/handler/admin/apikey_handler.go new file mode 100644 index 00000000..8dd245a4 --- /dev/null +++ b/backend/internal/handler/admin/apikey_handler.go @@ -0,0 +1,63 @@ +package admin + +import ( + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// AdminAPIKeyHandler handles admin API key management +type AdminAPIKeyHandler struct { + adminService service.AdminService +} + +// NewAdminAPIKeyHandler creates a new admin API key handler +func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandler { + return &AdminAPIKeyHandler{ + adminService: adminService, + } +} + +// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group +type AdminUpdateAPIKeyGroupRequest struct { + GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组 +} + +// UpdateGroup handles updating an API key's group binding +// PUT /api/v1/admin/api-keys/:id +func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) { + keyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid API key ID") + return + } + + var req AdminUpdateAPIKeyGroupRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + resp := struct { + APIKey *dto.APIKey `json:"api_key"` + AutoGrantedGroupAccess bool `json:"auto_granted_group_access"` + GrantedGroupID *int64 `json:"granted_group_id,omitempty"` + GrantedGroupName string `json:"granted_group_name,omitempty"` + }{ + APIKey: dto.APIKeyFromService(result.APIKey), + AutoGrantedGroupAccess: result.AutoGrantedGroupAccess, + GrantedGroupID: result.GrantedGroupID, + GrantedGroupName: result.GrantedGroupName, + } + response.Success(c, resp) +} diff --git a/backend/internal/handler/admin/apikey_handler_test.go b/backend/internal/handler/admin/apikey_handler_test.go new file mode 100644 index 00000000..bf128b18 --- /dev/null +++ b/backend/internal/handler/admin/apikey_handler_test.go @@ -0,0 +1,202 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupAPIKeyHandler(adminSvc service.AdminService) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + h := NewAdminAPIKeyHandler(adminSvc) + router.PUT("/api/v1/admin/api-keys/:id", h.UpdateGroup) + return router +} + +func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) { + router := setupAPIKeyHandler(newStubAdminService()) + body := `{"group_id": 2}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/abc", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "Invalid API key ID") +} + +func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) { + router := setupAPIKeyHandler(newStubAdminService()) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{bad json`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "Invalid request") +} + +func TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound(t *testing.T) { + router := setupAPIKeyHandler(newStubAdminService()) + body := `{"group_id": 2}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/999", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + // ErrAPIKeyNotFound maps to 404 + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestAdminAPIKeyHandler_UpdateGroup_BindGroup(t *testing.T) { + router := setupAPIKeyHandler(newStubAdminService()) + body := `{"group_id": 2}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Code int `json:"code"` + Data json.RawMessage `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + + var data struct { + APIKey struct { + ID int64 `json:"id"` + GroupID *int64 `json:"group_id"` + } `json:"api_key"` + AutoGrantedGroupAccess bool `json:"auto_granted_group_access"` + } + require.NoError(t, json.Unmarshal(resp.Data, &data)) + require.Equal(t, int64(10), data.APIKey.ID) + require.NotNil(t, data.APIKey.GroupID) + require.Equal(t, int64(2), *data.APIKey.GroupID) +} + +func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) { + svc := newStubAdminService() + gid := int64(2) + svc.apiKeys[0].GroupID = &gid + router := setupAPIKeyHandler(svc) + body := `{"group_id": 0}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Data struct { + APIKey struct { + GroupID *int64 `json:"group_id"` + } `json:"api_key"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Nil(t, resp.Data.APIKey.GroupID) +} + +func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) { + svc := &failingUpdateGroupService{ + stubAdminService: newStubAdminService(), + err: errors.New("internal failure"), + } + router := setupAPIKeyHandler(svc) + body := `{"group_id": 2}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// H2: empty body → group_id is nil → no-op, returns original key +func TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange(t *testing.T) { + router := setupAPIKeyHandler(newStubAdminService()) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + APIKey struct { + ID int64 `json:"id"` + } `json:"api_key"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, int64(10), resp.Data.APIKey.ID) +} + +// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400 +func TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive(t *testing.T) { + svc := &failingUpdateGroupService{ + stubAdminService: newStubAdminService(), + err: infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active"), + } + router := setupAPIKeyHandler(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": 5}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "GROUP_NOT_ACTIVE") +} + +// M2: service returns INVALID_GROUP_ID → handler maps to 400 +func TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID(t *testing.T) { + svc := &failingUpdateGroupService{ + stubAdminService: newStubAdminService(), + err: infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative"), + } + router := setupAPIKeyHandler(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": -5}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "INVALID_GROUP_ID") +} + +// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error. +type failingUpdateGroupService struct { + *stubAdminService + err error +} + +func (f *failingUpdateGroupService) AdminUpdateAPIKeyGroupID(_ context.Context, _ int64, _ *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) { + return nil, f.err +} diff --git a/backend/internal/handler/admin/batch_update_credentials_test.go b/backend/internal/handler/admin/batch_update_credentials_test.go new file mode 100644 index 00000000..0b1b6691 --- /dev/null +++ b/backend/internal/handler/admin/batch_update_credentials_test.go @@ -0,0 +1,208 @@ +//go:build unit + +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。 +type failingAdminService struct { + *stubAdminService + failOnAccountID int64 + updateCallCount atomic.Int64 +} + +func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { + f.updateCallCount.Add(1) + if id == f.failOnAccountID { + return nil, errors.New("database error") + } + return f.stubAdminService.UpdateAccount(ctx, id, input) +} + +func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) { + gin.SetMode(gin.TestMode) + router := gin.New() + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials) + return router, handler +} + +func TestBatchUpdateCredentials_AllSuccess(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "account_uuid", + Value: "test-uuid", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200") + require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount") +} + +func TestBatchUpdateCredentials_PartialFailure(t *testing.T) { + // 让第 2 个账号(ID=2)更新时失败 + svc := &failingAdminService{ + stubAdminService: newStubAdminService(), + failOnAccountID: 2, + } + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "org_uuid", + Value: "test-org", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + // 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细 + require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细") + + var resp map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + data := resp["data"].(map[string]any) + require.Equal(t, float64(2), data["success"], "应有 2 个成功") + require.Equal(t, float64(1), data["failed"], "应有 1 个失败") + + // 所有 3 个账号都会被尝试更新(非 fail-fast) + require.Equal(t, int64(3), svc.updateCallCount.Load(), + "应调用 3 次 UpdateAccount(逐个尝试,失败后继续)") +} + +func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) { + // GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub + svc := &getAccountFailingService{ + stubAdminService: newStubAdminService(), + failOnAccountID: 1, + } + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "account_uuid", + Value: "test", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404") +} + +// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。 +type getAccountFailingService struct { + *stubAdminService + failOnAccountID int64 +} + +func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) { + if id == f.failOnAccountID { + return nil, errors.New("not found") + } + return f.stubAdminService.GetAccount(ctx, id) +} + +func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // intercept_warmup_requests 传入非 bool 类型(string),应返回 400 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "intercept_warmup_requests", + "value": "not-a-bool", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code, + "intercept_warmup_requests 传入非 bool 值应返回 400") +} + +func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "intercept_warmup_requests", + "value": true, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, + "intercept_warmup_requests 传入合法 bool 值应返回 200") +} + +func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // account_uuid 传入非 string 类型(number),应返回 400 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "account_uuid", + "value": 12345, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code, + "account_uuid 传入非 string 值应返回 400") +} + +func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // account_uuid 传入 null(设置为空),应正常通过 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "account_uuid", + "value": nil, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, + "account_uuid 传入 null 应返回 200") +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 18365186..1d48c653 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -3,6 +3,7 @@ package admin import ( "errors" "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -186,7 +187,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) { // GetUsageTrend handles getting usage trend data // GET /api/v1/admin/dashboard/trend -// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type +// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { startTime, endTime := parseTimeRange(c) granularity := c.DefaultQuery("granularity", "day") @@ -194,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 var model string + var requestType *int16 var stream *bool var billingType *int8 @@ -220,9 +222,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { if modelStr := c.Query("model"); modelStr != "" { model = modelStr } - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { if streamVal, err := strconv.ParseBool(streamStr); err == nil { stream = &streamVal + } else { + response.BadRequest(c, "Invalid stream value, use true or false") + return } } if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { @@ -235,7 +248,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { } } - trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) + trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get usage trend") return @@ -251,12 +264,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { // GetModelStats handles getting model usage statistics // GET /api/v1/admin/dashboard/models -// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type +// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type func (h *DashboardHandler) GetModelStats(c *gin.Context) { startTime, endTime := parseTimeRange(c) // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 + var requestType *int16 var stream *bool var billingType *int8 @@ -280,9 +294,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { groupID = id } } - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { if streamVal, err := strconv.ParseBool(streamStr); err == nil { stream = &streamVal + } else { + response.BadRequest(c, "Invalid stream value, use true or false") + return } } if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { @@ -295,7 +320,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { } } - stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) + stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return @@ -308,6 +333,76 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { }) } +// GetGroupStats handles getting group usage statistics +// GET /api/v1/admin/dashboard/groups +// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type +func (h *DashboardHandler) GetGroupStats(c *gin.Context) { + startTime, endTime := parseTimeRange(c) + + var userID, apiKeyID, accountID, groupID int64 + var requestType *int16 + var stream *bool + var billingType *int8 + + if userIDStr := c.Query("user_id"); userIDStr != "" { + if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil { + userID = id + } + } + if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" { + if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil { + apiKeyID = id + } + } + if accountIDStr := c.Query("account_id"); accountIDStr != "" { + if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil { + accountID = id + } + } + if groupIDStr := c.Query("group_id"); groupIDStr != "" { + if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil { + groupID = id + } + } + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { + if streamVal, err := strconv.ParseBool(streamStr); err == nil { + stream = &streamVal + } else { + response.BadRequest(c, "Invalid stream value, use true or false") + return + } + } + if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { + if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil { + bt := int8(v) + billingType = &bt + } else { + response.BadRequest(c, "Invalid billing_type") + return + } + } + + stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + if err != nil { + response.Error(c, 500, "Failed to get group statistics") + return + } + + response.Success(c, gin.H{ + "groups": stats, + "start_date": startTime.Format("2006-01-02"), + "end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"), + }) +} + // GetAPIKeyUsageTrend handles getting API key usage trend data // GET /api/v1/admin/dashboard/api-keys-trend // Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5) @@ -379,7 +474,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { return } - stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs) + stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get user usage stats") return @@ -407,7 +502,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) { return } - stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs) + stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get API key usage stats") return diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go new file mode 100644 index 00000000..72af6b45 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go @@ -0,0 +1,132 @@ +package admin + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type dashboardUsageRepoCapture struct { + service.UsageLogRepository + trendRequestType *int16 + trendStream *bool + modelRequestType *int16 + modelStream *bool +} + +func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + userID, apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.TrendDataPoint, error) { + s.trendRequestType = requestType + s.trendStream = stream + return []usagestats.TrendDataPoint{}, nil +} + +func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters( + ctx context.Context, + startTime, endTime time.Time, + userID, apiKeyID, accountID, groupID int64, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.ModelStat, error) { + s.modelRequestType = requestType + s.modelStream = stream + return []usagestats.ModelStat{}, nil +} + +func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) + handler := NewDashboardHandler(dashboardSvc, nil) + router := gin.New() + router.GET("/admin/dashboard/trend", handler.GetUsageTrend) + router.GET("/admin/dashboard/models", handler.GetModelStats) + return router +} + +func TestDashboardTrendRequestTypePriority(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.trendRequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType) + require.Nil(t, repo.trendStream) +} + +func TestDashboardTrendInvalidRequestType(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardTrendInvalidStream(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardModelStatsRequestTypePriority(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.modelRequestType) + require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType) + require.Nil(t, repo.modelStream) +} + +func TestDashboardModelStatsInvalidRequestType(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardModelStatsInvalidStream(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/admin/data_management_handler.go b/backend/internal/handler/admin/data_management_handler.go new file mode 100644 index 00000000..02fc766f --- /dev/null +++ b/backend/internal/handler/admin/data_management_handler.go @@ -0,0 +1,545 @@ +package admin + +import ( + "context" + "strconv" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +type DataManagementHandler struct { + dataManagementService dataManagementService +} + +func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler { + return &DataManagementHandler{dataManagementService: dataManagementService} +} + +type dataManagementService interface { + GetConfig(ctx context.Context) (service.DataManagementConfig, error) + UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error) + ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error) + CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error) + ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error) + CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error) + UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error) + DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error + SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error) + ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error) + CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error) + UpdateS3Profile(ctx context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error) + DeleteS3Profile(ctx context.Context, profileID string) error + SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error) + ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error) + GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error) + EnsureAgentEnabled(ctx context.Context) error + GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth +} + +type TestS3ConnectionRequest struct { + Endpoint string `json:"endpoint"` + Region string `json:"region" binding:"required"` + Bucket string `json:"bucket" binding:"required"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +type CreateBackupJobRequest struct { + BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"` + UploadToS3 bool `json:"upload_to_s3"` + S3ProfileID string `json:"s3_profile_id"` + PostgresID string `json:"postgres_profile_id"` + RedisID string `json:"redis_profile_id"` + IdempotencyKey string `json:"idempotency_key"` +} + +type CreateSourceProfileRequest struct { + ProfileID string `json:"profile_id" binding:"required"` + Name string `json:"name" binding:"required"` + Config service.DataManagementSourceConfig `json:"config" binding:"required"` + SetActive bool `json:"set_active"` +} + +type UpdateSourceProfileRequest struct { + Name string `json:"name" binding:"required"` + Config service.DataManagementSourceConfig `json:"config" binding:"required"` +} + +type CreateS3ProfileRequest struct { + ProfileID string `json:"profile_id" binding:"required"` + Name string `json:"name" binding:"required"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` + SetActive bool `json:"set_active"` +} + +type UpdateS3ProfileRequest struct { + Name string `json:"name" binding:"required"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) { + health := h.getAgentHealth(c) + payload := gin.H{ + "enabled": health.Enabled, + "reason": health.Reason, + "socket_path": health.SocketPath, + } + if health.Agent != nil { + payload["agent"] = gin.H{ + "status": health.Agent.Status, + "version": health.Agent.Version, + "uptime_seconds": health.Agent.UptimeSeconds, + } + } + response.Success(c, payload) +} + +func (h *DataManagementHandler) GetConfig(c *gin.Context) { + if !h.requireAgentEnabled(c) { + return + } + cfg, err := h.dataManagementService.GetConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *DataManagementHandler) UpdateConfig(c *gin.Context) { + var req service.DataManagementConfig + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *DataManagementHandler) TestS3(c *gin.Context) { + var req TestS3ConnectionRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{ + Enabled: true, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"ok": result.OK, "message": result.Message}) +} + +func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) { + var req CreateBackupJobRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey) + if !h.requireAgentEnabled(c) { + return + } + + triggeredBy := "admin:unknown" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{ + BackupType: req.BackupType, + UploadToS3: req.UploadToS3, + S3ProfileID: req.S3ProfileID, + PostgresID: req.PostgresID, + RedisID: req.RedisID, + TriggeredBy: triggeredBy, + IdempotencyKey: req.IdempotencyKey, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status}) +} + +func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType == "" { + response.BadRequest(c, "Invalid source_type") + return + } + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + + if !h.requireAgentEnabled(c) { + return + } + items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"items": items}) +} + +func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + + var req CreateSourceProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{ + SourceType: sourceType, + ProfileID: req.ProfileID, + Name: req.Name, + Config: req.Config, + SetActive: req.SetActive, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + var req UpdateSourceProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{ + SourceType: sourceType, + ProfileID: profileID, + Name: req.Name, + Config: req.Config, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) { + if !h.requireAgentEnabled(c) { + return + } + + items, err := h.dataManagementService.ListS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"items": items}) +} + +func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) { + var req CreateS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + + profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{ + ProfileID: req.ProfileID, + Name: req.Name, + SetActive: req.SetActive, + S3: service.DataManagementS3Config{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) { + var req UpdateS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + + profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{ + ProfileID: profileID, + Name: req.Name, + S3: service.DataManagementS3Config{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) { + if !h.requireAgentEnabled(c) { + return + } + + pageSize := int32(20) + if raw := strings.TrimSpace(c.Query("page_size")); raw != "" { + v, err := strconv.Atoi(raw) + if err != nil || v <= 0 { + response.BadRequest(c, "Invalid page_size") + return + } + pageSize = int32(v) + } + + result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{ + PageSize: pageSize, + PageToken: c.Query("page_token"), + Status: c.Query("status"), + BackupType: c.Query("backup_type"), + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *DataManagementHandler) GetBackupJob(c *gin.Context) { + jobID := strings.TrimSpace(c.Param("job_id")) + if jobID == "" { + response.BadRequest(c, "Invalid backup job ID") + return + } + + if !h.requireAgentEnabled(c) { + return + } + job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, job) +} + +func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool { + if h.dataManagementService == nil { + err := infraerrors.ServiceUnavailable( + service.DataManagementAgentUnavailableReason, + "data management agent service is not configured", + ).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath}) + response.ErrorFrom(c, err) + return false + } + + if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return false + } + + return true +} + +func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth { + if h.dataManagementService == nil { + return service.DataManagementAgentHealth{ + Enabled: false, + Reason: service.DataManagementAgentUnavailableReason, + SocketPath: service.DefaultDataManagementAgentSocketPath, + } + } + return h.dataManagementService.GetAgentHealth(c.Request.Context()) +} + +func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string { + headerKey := strings.TrimSpace(headerValue) + if headerKey != "" { + return headerKey + } + return strings.TrimSpace(bodyValue) +} diff --git a/backend/internal/handler/admin/data_management_handler_test.go b/backend/internal/handler/admin/data_management_handler_test.go new file mode 100644 index 00000000..ce8ee835 --- /dev/null +++ b/backend/internal/handler/admin/data_management_handler_test.go @@ -0,0 +1,78 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type apiEnvelope struct { + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason"` + Data json.RawMessage `json:"data"` +} + +func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond) + h := NewDataManagementHandler(svc) + + r := gin.New() + r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil) + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var envelope apiEnvelope + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope)) + require.Equal(t, 0, envelope.Code) + + var data struct { + Enabled bool `json:"enabled"` + Reason string `json:"reason"` + SocketPath string `json:"socket_path"` + } + require.NoError(t, json.Unmarshal(envelope.Data, &data)) + require.False(t, data.Enabled) + require.Equal(t, service.DataManagementDeprecatedReason, data.Reason) + require.Equal(t, svc.SocketPath(), data.SocketPath) +} + +func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond) + h := NewDataManagementHandler(svc) + + r := gin.New() + r.GET("/api/v1/admin/data-management/config", h.GetConfig) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil) + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + + var envelope apiEnvelope + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope)) + require.Equal(t, http.StatusServiceUnavailable, envelope.Code) + require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason) +} + +func TestNormalizeBackupIdempotencyKey(t *testing.T) { + require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body")) + require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body ")) + require.Equal(t, "", normalizeBackupIdempotencyKey("", "")) +} diff --git a/backend/internal/handler/admin/error_passthrough_handler.go b/backend/internal/handler/admin/error_passthrough_handler.go index c32db561..25aaa5c7 100644 --- a/backend/internal/handler/admin/error_passthrough_handler.go +++ b/backend/internal/handler/admin/error_passthrough_handler.go @@ -32,6 +32,7 @@ type CreateErrorPassthroughRuleRequest struct { ResponseCode *int `json:"response_code"` PassthroughBody *bool `json:"passthrough_body"` CustomMessage *string `json:"custom_message"` + SkipMonitoring *bool `json:"skip_monitoring"` Description *string `json:"description"` } @@ -48,6 +49,7 @@ type UpdateErrorPassthroughRuleRequest struct { ResponseCode *int `json:"response_code"` PassthroughBody *bool `json:"passthrough_body"` CustomMessage *string `json:"custom_message"` + SkipMonitoring *bool `json:"skip_monitoring"` Description *string `json:"description"` } @@ -122,6 +124,9 @@ func (h *ErrorPassthroughHandler) Create(c *gin.Context) { } else { rule.PassthroughBody = true } + if req.SkipMonitoring != nil { + rule.SkipMonitoring = *req.SkipMonitoring + } rule.ResponseCode = req.ResponseCode rule.CustomMessage = req.CustomMessage rule.Description = req.Description @@ -190,6 +195,7 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) { ResponseCode: existing.ResponseCode, PassthroughBody: existing.PassthroughBody, CustomMessage: existing.CustomMessage, + SkipMonitoring: existing.SkipMonitoring, Description: existing.Description, } @@ -230,6 +236,9 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) { if req.Description != nil { rule.Description = req.Description } + if req.SkipMonitoring != nil { + rule.SkipMonitoring = *req.SkipMonitoring + } // 确保切片不为 nil if rule.ErrorCodes == nil { diff --git a/backend/internal/handler/admin/gemini_oauth_handler.go b/backend/internal/handler/admin/gemini_oauth_handler.go index 50caaa26..8c398a1e 100644 --- a/backend/internal/handler/admin/gemini_oauth_handler.go +++ b/backend/internal/handler/admin/gemini_oauth_handler.go @@ -61,7 +61,11 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) { if err != nil { msg := err.Error() // Treat missing/invalid OAuth client configuration as a user/config error. - if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") { + if strings.Contains(msg, "OAuth client not configured") || + strings.Contains(msg, "requires your own OAuth Client") || + strings.Contains(msg, "requires a custom OAuth Client") || + strings.Contains(msg, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING") || + strings.Contains(msg, "built-in Gemini CLI OAuth client_secret is not configured") { response.BadRequest(c, "Failed to generate auth URL: "+msg) return } diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index d10d678b..1edf4dcc 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler { type CreateGroupRequest struct { Name string `json:"name" binding:"required"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` RateMultiplier float64 `json:"rate_multiplier"` IsExclusive bool `json:"is_exclusive"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` @@ -38,6 +38,10 @@ type CreateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` @@ -47,6 +51,8 @@ type CreateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` + // Sora 存储配额 + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -55,7 +61,7 @@ type CreateGroupRequest struct { type UpdateGroupRequest struct { Name string `json:"name"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` RateMultiplier *float64 `json:"rate_multiplier"` IsExclusive *bool `json:"is_exclusive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` @@ -67,6 +73,10 @@ type UpdateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly *bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` @@ -76,6 +86,8 @@ type UpdateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string `json:"supported_model_scopes"` + // Sora 存储配额 + SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -179,6 +191,10 @@ func (h *GroupHandler) Create(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, @@ -186,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -225,6 +242,10 @@ func (h *GroupHandler) Update(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, @@ -232,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -302,3 +324,36 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) { } response.Paginated(c, outKeys, total, page, pageSize) } + +// UpdateSortOrderRequest represents the request to update group sort orders +type UpdateSortOrderRequest struct { + Updates []struct { + ID int64 `json:"id" binding:"required"` + SortOrder int `json:"sort_order"` + } `json:"updates" binding:"required,min=1"` +} + +// UpdateSortOrder handles updating group sort orders +// PUT /api/v1/admin/groups/sort-order +func (h *GroupHandler) UpdateSortOrder(c *gin.Context) { + var req UpdateSortOrderRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + updates := make([]service.GroupSortOrderUpdate, 0, len(req.Updates)) + for _, u := range req.Updates { + updates = append(updates, service.GroupSortOrderUpdate{ + ID: u.ID, + SortOrder: u.SortOrder, + }) + } + + if err := h.adminService.UpdateGroupSortOrders(c.Request.Context(), updates); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Sort order updated successfully"}) +} diff --git a/backend/internal/handler/admin/idempotency_helper.go b/backend/internal/handler/admin/idempotency_helper.go new file mode 100644 index 00000000..aa8eeaaf --- /dev/null +++ b/backend/internal/handler/admin/idempotency_helper.go @@ -0,0 +1,115 @@ +package admin + +import ( + "context" + "strconv" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +type idempotencyStoreUnavailableMode int + +const ( + idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota + idempotencyStoreUnavailableFailOpen +) + +func executeAdminIdempotent( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) (*service.IdempotencyExecuteResult, error) { + coordinator := service.DefaultIdempotencyCoordinator() + if coordinator == nil { + data, err := execute(c.Request.Context()) + if err != nil { + return nil, err + } + return &service.IdempotencyExecuteResult{Data: data}, nil + } + + actorScope := "admin:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + + return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{ + Scope: scope, + ActorScope: actorScope, + Method: c.Request.Method, + Route: c.FullPath(), + IdempotencyKey: c.GetHeader("Idempotency-Key"), + Payload: payload, + RequireKey: true, + TTL: ttl, + }, execute) +} + +func executeAdminIdempotentJSON( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute) +} + +func executeAdminIdempotentJSONFailOpenOnStoreUnavailable( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute) +} + +func executeAdminIdempotentJSONWithMode( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + mode idempotencyStoreUnavailableMode, + execute func(context.Context) (any, error), +) { + result, err := executeAdminIdempotent(c, scope, payload, ttl, execute) + if err != nil { + if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) { + strategy := "fail_close" + if mode == idempotencyStoreUnavailableFailOpen { + strategy = "fail_open" + } + service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy) + logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy) + if mode == idempotencyStoreUnavailableFailOpen { + data, fallbackErr := execute(c.Request.Context()) + if fallbackErr != nil { + response.ErrorFrom(c, fallbackErr) + return + } + c.Header("X-Idempotency-Degraded", "store-unavailable") + response.Success(c, data) + return + } + } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } + response.ErrorFrom(c, err) + return + } + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) +} diff --git a/backend/internal/handler/admin/idempotency_helper_test.go b/backend/internal/handler/admin/idempotency_helper_test.go new file mode 100644 index 00000000..7dd86e16 --- /dev/null +++ b/backend/internal/handler/admin/idempotency_helper_test.go @@ -0,0 +1,285 @@ +package admin + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type storeUnavailableRepoStub struct{} + +func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return errors.New("store unavailable") +} +func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "test-key-1") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable") +} + +func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "test-key-2") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded")) + require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue") +} + +type memoryIdempotencyRepoStub struct { + mu sync.Mutex + nextID int64 + data map[string]*service.IdempotencyRecord +} + +func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub { + return &memoryIdempotencyRepoStub{ + nextID: 1, + data: make(map[string]*service.IdempotencyRecord), + } +} + +func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string { + return scope + "|" + keyHash +} + +func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + return &out +} + +func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + cp := r.clone(record) + cp.ID = r.nextID + r.nextID++ + r.data[k] = cp + record.ID = cp.ID + return true, nil +} + +func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.clone(r.data[r.key(scope, keyHash)]), nil +} + +func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = service.IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + return true, nil + } + return false, nil +} + +func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + return true, nil + } + return false, nil +} + +func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + rec.ErrorReason = nil + return nil + } + return nil +} + +func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.ErrorReason = &errorReason + return nil + } + return nil +} + +func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) { + return 0, nil +} + +func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := newMemoryIdempotencyRepoStub() + cfg := service.DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg)) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed atomic.Int32 + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed.Add(1) + time.Sleep(120 * time.Millisecond) + return gin.H{"ok": true}, nil + }) + }) + + call := func() (int, http.Header) { + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "same-key") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + return rec.Code, rec.Header() + } + + var status1, status2 int + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + status1, _ = call() + }() + go func() { + defer wg.Done() + status2, _ = call() + }() + wg.Wait() + + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1) + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2) + require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once") + + status3, headers3 := call() + require.Equal(t, http.StatusOK, status3) + require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed")) + require.Equal(t, int32(1), executed.Load()) +} diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go index ed86fea9..5d354fd3 100644 --- a/backend/internal/handler/admin/openai_oauth_handler.go +++ b/backend/internal/handler/admin/openai_oauth_handler.go @@ -2,8 +2,10 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -16,6 +18,13 @@ type OpenAIOAuthHandler struct { adminService service.AdminService } +func oauthPlatformFromPath(c *gin.Context) string { + if strings.Contains(c.FullPath(), "/admin/sora/") { + return service.PlatformSora + } + return service.PlatformOpenAI +} + // NewOpenAIOAuthHandler creates a new OpenAI OAuth handler func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler { return &OpenAIOAuthHandler{ @@ -39,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { req = OpenAIGenerateAuthURLRequest{} } - result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI) + result, err := h.openaiOAuthService.GenerateAuthURL( + c.Request.Context(), + req.ProxyID, + req.RedirectURI, + oauthPlatformFromPath(c), + ) if err != nil { response.ErrorFrom(c, err) return @@ -52,6 +66,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { type OpenAIExchangeCodeRequest struct { SessionID string `json:"session_id" binding:"required"` Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` RedirectURI string `json:"redirect_uri"` ProxyID *int64 `json:"proxy_id"` } @@ -68,6 +83,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ SessionID: req.SessionID, Code: req.Code, + State: req.State, RedirectURI: req.RedirectURI, ProxyID: req.ProxyID, }) @@ -81,18 +97,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { // OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token type OpenAIRefreshTokenRequest struct { - RefreshToken string `json:"refresh_token" binding:"required"` + RefreshToken string `json:"refresh_token"` + RT string `json:"rt"` + ClientID string `json:"client_id"` ProxyID *int64 `json:"proxy_id"` } // RefreshToken refreshes an OpenAI OAuth token // POST /api/v1/admin/openai/refresh-token +// POST /api/v1/admin/sora/rt2at func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { var req OpenAIRefreshTokenRequest if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, "Invalid request: "+err.Error()) return } + refreshToken := strings.TrimSpace(req.RefreshToken) + if refreshToken == "" { + refreshToken = strings.TrimSpace(req.RT) + } + if refreshToken == "" { + response.BadRequest(c, "refresh_token is required") + return + } var proxyURL string if req.ProxyID != nil { @@ -102,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { } } - tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) + // 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜 + clientID := strings.TrimSpace(req.ClientID) + if clientID == "" { + platform := oauthPlatformFromPath(c) + clientID, _ = openai.OAuthClientConfigByPlatform(platform) + } + + tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID) if err != nil { response.ErrorFrom(c, err) return @@ -111,8 +145,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { response.Success(c, tokenInfo) } -// RefreshAccountToken refreshes token for a specific OpenAI account +// ExchangeSoraSessionToken exchanges Sora session token to access token +// POST /api/v1/admin/sora/st2at +func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) { + var req struct { + SessionToken string `json:"session_token"` + ST string `json:"st"` + ProxyID *int64 `json:"proxy_id"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + sessionToken := strings.TrimSpace(req.SessionToken) + if sessionToken == "" { + sessionToken = strings.TrimSpace(req.ST) + } + if sessionToken == "" { + response.BadRequest(c, "session_token is required") + return + } + + tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, tokenInfo) +} + +// RefreshAccountToken refreshes token for a specific OpenAI/Sora account // POST /api/v1/admin/openai/accounts/:id/refresh +// POST /api/v1/admin/sora/accounts/:id/refresh func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { @@ -127,9 +192,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { return } - // Ensure account is OpenAI platform - if !account.IsOpenAI() { - response.BadRequest(c, "Account is not an OpenAI account") + platform := oauthPlatformFromPath(c) + if account.Platform != platform { + response.BadRequest(c, "Account platform does not match OAuth endpoint") return } @@ -167,12 +232,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { response.Success(c, dto.AccountFromService(updatedAccount)) } -// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info +// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info // POST /api/v1/admin/openai/create-from-oauth +// POST /api/v1/admin/sora/create-from-oauth func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { var req struct { SessionID string `json:"session_id" binding:"required"` Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` RedirectURI string `json:"redirect_uri"` ProxyID *int64 `json:"proxy_id"` Name string `json:"name"` @@ -189,6 +256,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ SessionID: req.SessionID, Code: req.Code, + State: req.State, RedirectURI: req.RedirectURI, ProxyID: req.ProxyID, }) @@ -200,19 +268,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { // Build credentials from token info credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) + platform := oauthPlatformFromPath(c) + // Use email as default name if not provided name := req.Name if name == "" && tokenInfo.Email != "" { name = tokenInfo.Email } if name == "" { - name = "OpenAI OAuth Account" + if platform == service.PlatformSora { + name = "Sora OAuth Account" + } else { + name = "OpenAI OAuth Account" + } } // Create account account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ Name: name, - Platform: "openai", + Platform: platform, Type: "oauth", Credentials: credentials, ProxyID: req.ProxyID, diff --git a/backend/internal/handler/admin/ops_dashboard_handler.go b/backend/internal/handler/admin/ops_dashboard_handler.go index 2c87f734..01f7bc2b 100644 --- a/backend/internal/handler/admin/ops_dashboard_handler.go +++ b/backend/internal/handler/admin/ops_dashboard_handler.go @@ -1,6 +1,7 @@ package admin import ( + "fmt" "net/http" "strconv" "strings" @@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) { response.Success(c, data) } +// GetDashboardOpenAITokenStats returns OpenAI token efficiency stats grouped by model. +// GET /api/v1/admin/ops/dashboard/openai-token-stats +func (h *OpsHandler) GetDashboardOpenAITokenStats(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + filter, err := parseOpsOpenAITokenStatsFilter(c) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + data, err := h.opsService.GetOpenAITokenStats(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) +} + +func parseOpsOpenAITokenStatsFilter(c *gin.Context) (*service.OpsOpenAITokenStatsFilter, error) { + if c == nil { + return nil, fmt.Errorf("invalid request") + } + + timeRange := strings.TrimSpace(c.Query("time_range")) + if timeRange == "" { + timeRange = "30d" + } + dur, ok := parseOpsOpenAITokenStatsDuration(timeRange) + if !ok { + return nil, fmt.Errorf("invalid time_range") + } + end := time.Now().UTC() + start := end.Add(-dur) + + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: timeRange, + StartTime: start, + EndTime: end, + Platform: strings.TrimSpace(c.Query("platform")), + } + + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + return nil, fmt.Errorf("invalid group_id") + } + filter.GroupID = &id + } + + topNRaw := strings.TrimSpace(c.Query("top_n")) + pageRaw := strings.TrimSpace(c.Query("page")) + pageSizeRaw := strings.TrimSpace(c.Query("page_size")) + if topNRaw != "" && (pageRaw != "" || pageSizeRaw != "") { + return nil, fmt.Errorf("invalid query: top_n cannot be used with page/page_size") + } + + if topNRaw != "" { + topN, err := strconv.Atoi(topNRaw) + if err != nil || topN < 1 || topN > 100 { + return nil, fmt.Errorf("invalid top_n") + } + filter.TopN = topN + return filter, nil + } + + filter.Page = 1 + filter.PageSize = 20 + if pageRaw != "" { + page, err := strconv.Atoi(pageRaw) + if err != nil || page < 1 { + return nil, fmt.Errorf("invalid page") + } + filter.Page = page + } + if pageSizeRaw != "" { + pageSize, err := strconv.Atoi(pageSizeRaw) + if err != nil || pageSize < 1 || pageSize > 100 { + return nil, fmt.Errorf("invalid page_size") + } + filter.PageSize = pageSize + } + return filter, nil +} + +func parseOpsOpenAITokenStatsDuration(v string) (time.Duration, bool) { + switch strings.TrimSpace(v) { + case "30m": + return 30 * time.Minute, true + case "1h": + return time.Hour, true + case "1d": + return 24 * time.Hour, true + case "15d": + return 15 * 24 * time.Hour, true + case "30d": + return 30 * 24 * time.Hour, true + default: + return 0, false + } +} + func pickThroughputBucketSeconds(window time.Duration) int { // Keep buckets predictable and avoid huge responses. switch { diff --git a/backend/internal/handler/admin/ops_runtime_logging_handler_test.go b/backend/internal/handler/admin/ops_runtime_logging_handler_test.go new file mode 100644 index 00000000..0e84b4f9 --- /dev/null +++ b/backend/internal/handler/admin/ops_runtime_logging_handler_test.go @@ -0,0 +1,173 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type testSettingRepo struct { + values map[string]string +} + +func newTestSettingRepo() *testSettingRepo { + return &testSettingRepo{values: map[string]string{}} +} + +func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) { + v, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &service.Setting{Key: key, Value: v}, nil +} +func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) { + v, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return v, nil +} +func (s *testSettingRepo) Set(ctx context.Context, key, value string) error { + s.values[key] = value + return nil +} +func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, k := range keys { + if v, ok := s.values[k]; ok { + out[k] = v + } + } + return out, nil +} +func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error { + for k, v := range settings { + s.values[k] = v + } + return nil +} +func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for k, v := range s.values { + out[k] = v + } + return out, nil +} +func (s *testSettingRepo) Delete(ctx context.Context, key string) error { + delete(s.values, key) + return nil +} + +func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + if withUser { + r.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7}) + c.Next() + }) + } + r.GET("/runtime/logging", handler.GetRuntimeLogConfig) + r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig) + r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig) + return r +} + +func newRuntimeOpsService(t *testing.T) *service.OpsService { + t.Helper() + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: false, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + settingRepo := newTestSettingRepo() + cfg := &config.Config{ + Ops: config.OpsConfig{Enabled: true}, + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + } + return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil) +} + +func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) { + h := NewOpsHandler(newRuntimeOpsService(t)) + r := newOpsRuntimeRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d, want 200", w.Code) + } +} + +func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) { + h := NewOpsHandler(newRuntimeOpsService(t)) + r := newOpsRuntimeRouter(h, false) + + body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}` + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Fatalf("status=%d, want 401", w.Code) + } +} + +func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) { + h := NewOpsHandler(newRuntimeOpsService(t)) + r := newOpsRuntimeRouter(h, true) + + payload := map[string]any{ + "level": "debug", + "enable_sampling": false, + "sampling_initial": 100, + "sampling_thereafter": 100, + "caller": true, + "stacktrace_level": "error", + "retention_days": 30, + } + raw, _ := json.Marshal(payload) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String()) + } + + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String()) + } +} diff --git a/backend/internal/handler/admin/ops_settings_handler.go b/backend/internal/handler/admin/ops_settings_handler.go index ebc8bf49..226b89f3 100644 --- a/backend/internal/handler/admin/ops_settings_handler.go +++ b/backend/internal/handler/admin/ops_settings_handler.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) @@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) { response.Success(c, updated) } +// GetRuntimeLogConfig returns runtime log config (DB-backed). +// GET /api/v1/admin/ops/runtime/logging +func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context()) + if err != nil { + response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config") + return + } + response.Success(c, cfg) +} + +// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately. +// PUT /api/v1/admin/ops/runtime/logging +func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + var req service.OpsRuntimeLogConfig + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Error(c, http.StatusUnauthorized, "Unauthorized") + return + } + + updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID) + if err != nil { + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + response.Success(c, updated) +} + +// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline. +// POST /api/v1/admin/ops/runtime/logging/reset +func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Error(c, http.StatusUnauthorized, "Unauthorized") + return + } + + updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID) + if err != nil { + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + response.Success(c, updated) +} + // GetAdvancedSettings returns Ops advanced settings (DB-backed). // GET /api/v1/admin/ops/advanced-settings func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) { diff --git a/backend/internal/handler/admin/ops_system_log_handler.go b/backend/internal/handler/admin/ops_system_log_handler.go new file mode 100644 index 00000000..31fd51eb --- /dev/null +++ b/backend/internal/handler/admin/ops_system_log_handler.go @@ -0,0 +1,174 @@ +package admin + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type opsSystemLogCleanupRequest struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + + Level string `json:"level"` + Component string `json:"component"` + RequestID string `json:"request_id"` + ClientRequestID string `json:"client_request_id"` + UserID *int64 `json:"user_id"` + AccountID *int64 `json:"account_id"` + Platform string `json:"platform"` + Model string `json:"model"` + Query string `json:"q"` +} + +// ListSystemLogs returns indexed system logs. +// GET /api/v1/admin/ops/system-logs +func (h *OpsHandler) ListSystemLogs(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + page, pageSize := response.ParsePagination(c) + if pageSize > 200 { + pageSize = 200 + } + + start, end, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsSystemLogFilter{ + Page: page, + PageSize: pageSize, + StartTime: &start, + EndTime: &end, + Level: strings.TrimSpace(c.Query("level")), + Component: strings.TrimSpace(c.Query("component")), + RequestID: strings.TrimSpace(c.Query("request_id")), + ClientRequestID: strings.TrimSpace(c.Query("client_request_id")), + Platform: strings.TrimSpace(c.Query("platform")), + Model: strings.TrimSpace(c.Query("model")), + Query: strings.TrimSpace(c.Query("q")), + } + if v := strings.TrimSpace(c.Query("user_id")); v != "" { + id, parseErr := strconv.ParseInt(v, 10, 64) + if parseErr != nil || id <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + filter.UserID = &id + } + if v := strings.TrimSpace(c.Query("account_id")); v != "" { + id, parseErr := strconv.ParseInt(v, 10, 64) + if parseErr != nil || id <= 0 { + response.BadRequest(c, "Invalid account_id") + return + } + filter.AccountID = &id + } + + result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize) +} + +// CleanupSystemLogs deletes indexed system logs by filter. +// POST /api/v1/admin/ops/system-logs/cleanup +func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Error(c, http.StatusUnauthorized, "Unauthorized") + return + } + + var req opsSystemLogCleanupRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + + parseTS := func(raw string) (*time.Time, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, nil + } + if t, err := time.Parse(time.RFC3339Nano, raw); err == nil { + return &t, nil + } + t, err := time.Parse(time.RFC3339, raw) + if err != nil { + return nil, err + } + return &t, nil + } + start, err := parseTS(req.StartTime) + if err != nil { + response.BadRequest(c, "Invalid start_time") + return + } + end, err := parseTS(req.EndTime) + if err != nil { + response.BadRequest(c, "Invalid end_time") + return + } + + filter := &service.OpsSystemLogCleanupFilter{ + StartTime: start, + EndTime: end, + Level: strings.TrimSpace(req.Level), + Component: strings.TrimSpace(req.Component), + RequestID: strings.TrimSpace(req.RequestID), + ClientRequestID: strings.TrimSpace(req.ClientRequestID), + UserID: req.UserID, + AccountID: req.AccountID, + Platform: strings.TrimSpace(req.Platform), + Model: strings.TrimSpace(req.Model), + Query: strings.TrimSpace(req.Query), + } + + deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": deleted}) +} + +// GetSystemLogIngestionHealth returns sink health metrics. +// GET /api/v1/admin/ops/system-logs/health +func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, h.opsService.GetSystemLogSinkHealth()) +} diff --git a/backend/internal/handler/admin/ops_system_log_handler_test.go b/backend/internal/handler/admin/ops_system_log_handler_test.go new file mode 100644 index 00000000..7528acd8 --- /dev/null +++ b/backend/internal/handler/admin/ops_system_log_handler_test.go @@ -0,0 +1,233 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type responseEnvelope struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` +} + +func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + if withUser { + r.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99}) + c.Next() + }) + } + r.GET("/logs", handler.ListSystemLogs) + r.POST("/logs/cleanup", handler.CleanupSystemLogs) + r.GET("/logs/health", handler.GetSystemLogIngestionHealth) + return r +} + +func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) { + h := NewOpsHandler(nil) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("status=%d, want 503", w.Code) + } +} + +func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) { + svc := service.NewOpsService(nil, nil, &config.Config{ + Ops: config.OpsConfig{Enabled: false}, + }, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("status=%d, want 404", w.Code) + } +} + +func TestOpsSystemLogHandler_ListSuccess(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d, want 200", w.Code) + } + + var resp responseEnvelope + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.Code != 0 { + t.Fatalf("unexpected response code: %+v", resp) + } +} + +func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Fatalf("status=%d, want 401", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("status=%d, want 503", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) { + svc := service.NewOpsService(nil, nil, &config.Config{ + Ops: config.OpsConfig{Enabled: false}, + }, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("status=%d, want 404", w.Code) + } +} + +func TestOpsSystemLogHandler_Health(t *testing.T) { + sink := service.NewOpsSystemLogSink(nil) + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d, want 200", w.Code) + } +} + +func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) { + h := NewOpsHandler(nil) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("status=%d, want 503", w.Code) + } + + svc := service.NewOpsService(nil, nil, &config.Config{ + Ops: config.OpsConfig{Enabled: false}, + }, nil, nil, nil, nil, nil, nil, nil, nil) + h = NewOpsHandler(svc) + r = newOpsSystemLogTestRouter(h, false) + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/logs/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("status=%d, want 404", w.Code) + } +} diff --git a/backend/internal/handler/admin/ops_ws_handler.go b/backend/internal/handler/admin/ops_ws_handler.go index db7442e5..75fd7ea0 100644 --- a/backend/internal/handler/admin/ops_ws_handler.go +++ b/backend/internal/handler/admin/ops_ws_handler.go @@ -3,7 +3,6 @@ package admin import ( "context" "encoding/json" - "log" "math" "net" "net/http" @@ -16,6 +15,7 @@ import ( "sync/atomic" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -62,7 +62,8 @@ const ( ) var wsConnCount atomic.Int32 -var wsConnCountByIP sync.Map // map[string]*atomic.Int32 +var wsConnCountByIPMu sync.Mutex +var wsConnCountByIP = make(map[string]int32) const qpsWSIdleStopDelay = 30 * time.Second @@ -252,7 +253,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) { stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now) if err != nil || stats == nil { if err != nil { - log.Printf("[OpsWS] refresh: get window stats failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err) } return } @@ -278,7 +279,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) { msg, err := json.Marshal(payload) if err != nil { - log.Printf("[OpsWS] refresh: marshal payload failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err) return } @@ -338,7 +339,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { // Reserve a global slot before upgrading the connection to keep the limit strict. if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) { - log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns) c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) return } @@ -350,7 +351,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" { if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) { - log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP) c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) return } @@ -359,7 +360,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - log.Printf("[OpsWS] upgrade failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err) return } @@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool { if strings.TrimSpace(clientIP) == "" || limit <= 0 { return true } - - v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{}) - counter, ok := v.(*atomic.Int32) - if !ok { + wsConnCountByIPMu.Lock() + defer wsConnCountByIPMu.Unlock() + current := wsConnCountByIP[clientIP] + if current >= limit { return false } - - for { - current := counter.Load() - if current >= limit { - return false - } - if counter.CompareAndSwap(current, current+1) { - return true - } - } + wsConnCountByIP[clientIP] = current + 1 + return true } func releaseOpsWSIPSlot(clientIP string) { if strings.TrimSpace(clientIP) == "" { return } - - v, ok := wsConnCountByIP.Load(clientIP) + wsConnCountByIPMu.Lock() + defer wsConnCountByIPMu.Unlock() + current, ok := wsConnCountByIP[clientIP] if !ok { return } - counter, ok := v.(*atomic.Int32) - if !ok { + if current <= 1 { + delete(wsConnCountByIP, clientIP) return } - next := counter.Add(-1) - if next <= 0 { - // Best-effort cleanup; safe even if a new slot was acquired concurrently. - wsConnCountByIP.Delete(clientIP) - } + wsConnCountByIP[clientIP] = current - 1 } func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { @@ -452,7 +442,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { conn.SetReadLimit(qpsWSMaxReadBytes) if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil { - log.Printf("[OpsWS] set read deadline failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err) return } conn.SetPongHandler(func(string) error { @@ -471,7 +461,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { _, _, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - log.Printf("[OpsWS] read failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err) } return } @@ -508,7 +498,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { continue } if err := writeWithTimeout(websocket.TextMessage, msg); err != nil { - log.Printf("[OpsWS] write failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err) cancel() closeConn() wg.Wait() @@ -517,7 +507,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { case <-pingTicker.C: if err := writeWithTimeout(websocket.PingMessage, nil); err != nil { - log.Printf("[OpsWS] ping failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err) cancel() closeConn() wg.Wait() @@ -666,14 +656,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig { if parsed, err := strconv.ParseBool(v); err == nil { cfg.TrustProxy = parsed } else { - log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy) } } if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" { prefixes, invalid := parseTrustedProxyList(raw) if len(invalid) > 0 { - log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", ")) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", ")) } cfg.TrustedProxies = prefixes } @@ -684,7 +674,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig { case OriginPolicyStrict, OriginPolicyPermissive: cfg.OriginPolicy = normalized default: - log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy) } } @@ -701,14 +691,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits { if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 { cfg.MaxConns = int32(parsed) } else { - log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns) } } if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" { if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 { cfg.MaxConnsPerIP = int32(parsed) } else { - log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP) } } return cfg diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index a6758f69..e8ae0ce2 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "strconv" "strings" @@ -63,9 +64,9 @@ func (h *ProxyHandler) List(c *gin.Context) { return } - out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) + out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies)) for i := range proxies { - out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) + out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i])) } response.Paginated(c, out, total, page, pageSize) } @@ -82,9 +83,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) { response.ErrorFrom(c, err) return } - out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) + out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies)) for i := range proxies { - out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) + out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i])) } response.Success(c, out) return @@ -96,9 +97,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) { return } - out := make([]dto.Proxy, 0, len(proxies)) + out := make([]dto.AdminProxy, 0, len(proxies)) for i := range proxies { - out = append(out, *dto.ProxyFromService(&proxies[i])) + out = append(out, *dto.ProxyFromServiceAdmin(&proxies[i])) } response.Success(c, out) } @@ -118,7 +119,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) { return } - response.Success(c, dto.ProxyFromService(proxy)) + response.Success(c, dto.ProxyFromServiceAdmin(proxy)) } // Create handles creating a new proxy @@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) { return } - proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ - Name: strings.TrimSpace(req.Name), - Protocol: strings.TrimSpace(req.Protocol), - Host: strings.TrimSpace(req.Host), - Port: req.Port, - Username: strings.TrimSpace(req.Username), - Password: strings.TrimSpace(req.Password), + executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{ + Name: strings.TrimSpace(req.Name), + Protocol: strings.TrimSpace(req.Protocol), + Host: strings.TrimSpace(req.Host), + Port: req.Port, + Username: strings.TrimSpace(req.Username), + Password: strings.TrimSpace(req.Password), + }) + if err != nil { + return nil, err + } + return dto.ProxyFromServiceAdmin(proxy), nil }) - if err != nil { - response.ErrorFrom(c, err) - return - } - - response.Success(c, dto.ProxyFromService(proxy)) } // Update handles updating a proxy @@ -175,7 +176,7 @@ func (h *ProxyHandler) Update(c *gin.Context) { return } - response.Success(c, dto.ProxyFromService(proxy)) + response.Success(c, dto.ProxyFromServiceAdmin(proxy)) } // Delete handles deleting a proxy @@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) { response.Success(c, result) } +// CheckQuality handles checking proxy quality across common AI targets. +// POST /api/v1/admin/proxies/:id/quality-check +func (h *ProxyHandler) CheckQuality(c *gin.Context) { + proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid proxy ID") + return + } + + result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, result) +} + // GetStats handles getting proxy statistics // GET /api/v1/admin/proxies/:id/stats func (h *ProxyHandler) GetStats(c *gin.Context) { diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index e229385f..0a932ee9 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -2,12 +2,15 @@ package admin import ( "bytes" + "context" "encoding/csv" + "errors" "fmt" "strconv" "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -16,13 +19,15 @@ import ( // RedeemHandler handles admin redeem code management type RedeemHandler struct { - adminService service.AdminService + adminService service.AdminService + redeemService *service.RedeemService } // NewRedeemHandler creates a new admin redeem handler -func NewRedeemHandler(adminService service.AdminService) *RedeemHandler { +func NewRedeemHandler(adminService service.AdminService, redeemService *service.RedeemService) *RedeemHandler { return &RedeemHandler{ - adminService: adminService, + adminService: adminService, + redeemService: redeemService, } } @@ -35,6 +40,15 @@ type GenerateRedeemCodesRequest struct { ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年 } +// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user. +type CreateAndRedeemCodeRequest struct { + Code string `json:"code" binding:"required,min=3,max=128"` + Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"` + Value float64 `json:"value" binding:"required,gt=0"` + UserID int64 `json:"user_id" binding:"required,gt=0"` + Notes string `json:"notes"` +} + // List handles listing all redeem codes with pagination // GET /api/v1/admin/redeem-codes func (h *RedeemHandler) List(c *gin.Context) { @@ -88,23 +102,99 @@ func (h *RedeemHandler) Generate(c *gin.Context) { return } - codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{ - Count: req.Count, - Type: req.Type, - Value: req.Value, - GroupID: req.GroupID, - ValidityDays: req.ValidityDays, + executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{ + Count: req.Count, + Type: req.Type, + Value: req.Value, + GroupID: req.GroupID, + ValidityDays: req.ValidityDays, + }) + if execErr != nil { + return nil, execErr + } + + out := make([]dto.AdminRedeemCode, 0, len(codes)) + for i := range codes { + out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) + } + return out, nil }) - if err != nil { - response.ErrorFrom(c, err) +} + +// CreateAndRedeem creates a fixed redeem code and redeems it for a target user in one step. +// POST /api/v1/admin/redeem-codes/create-and-redeem +func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) { + if h.redeemService == nil { + response.InternalError(c, "redeem service not configured") return } - out := make([]dto.AdminRedeemCode, 0, len(codes)) - for i := range codes { - out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) + var req CreateAndRedeemCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return } - response.Success(c, out) + req.Code = strings.TrimSpace(req.Code) + + executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + existing, err := h.redeemService.GetByCode(ctx, req.Code) + if err == nil { + return h.resolveCreateAndRedeemExisting(ctx, existing, req.UserID) + } + if !errors.Is(err, service.ErrRedeemCodeNotFound) { + return nil, err + } + + createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{ + Code: req.Code, + Type: req.Type, + Value: req.Value, + Status: service.StatusUnused, + Notes: req.Notes, + }) + if createErr != nil { + // Unique code race: if code now exists, use idempotent semantics by used_by. + existingAfterCreateErr, getErr := h.redeemService.GetByCode(ctx, req.Code) + if getErr == nil { + return h.resolveCreateAndRedeemExisting(ctx, existingAfterCreateErr, req.UserID) + } + return nil, createErr + } + + redeemed, redeemErr := h.redeemService.Redeem(ctx, req.UserID, req.Code) + if redeemErr != nil { + return nil, redeemErr + } + return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil + }) +} + +func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, existing *service.RedeemCode, userID int64) (any, error) { + if existing == nil { + return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code conflict") + } + + // If previous run created the code but crashed before redeem, redeem it now. + if existing.CanUse() { + redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code) + if err == nil { + return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil + } + if !errors.Is(err, service.ErrRedeemCodeUsed) { + return nil, err + } + latest, getErr := h.redeemService.GetByCode(ctx, existing.Code) + if getErr == nil { + existing = latest + } + } + + if existing.UsedBy != nil && *existing.UsedBy == userID { + return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(existing)}, nil + } + + return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code already used by another user") } // Delete handles deleting a redeem code @@ -202,7 +292,7 @@ func (h *RedeemHandler) Export(c *gin.Context) { writer := csv.NewWriter(&buf) // Write header - if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil { + if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil { response.InternalError(c, "Failed to export redeem codes: "+err.Error()) return } @@ -213,6 +303,10 @@ func (h *RedeemHandler) Export(c *gin.Context) { if code.UsedBy != nil { usedBy = fmt.Sprintf("%d", *code.UsedBy) } + usedByEmail := "" + if code.User != nil { + usedByEmail = code.User.Email + } usedAt := "" if code.UsedAt != nil { usedAt = code.UsedAt.Format("2006-01-02 15:04:05") @@ -224,6 +318,7 @@ func (h *RedeemHandler) Export(c *gin.Context) { fmt.Sprintf("%.2f", code.Value), code.Status, usedBy, + usedByEmail, usedAt, code.CreatedAt.Format("2006-01-02 15:04:05"), }); err != nil { diff --git a/backend/internal/handler/admin/search_truncate_test.go b/backend/internal/handler/admin/search_truncate_test.go new file mode 100644 index 00000000..ffd60e2a --- /dev/null +++ b/backend/internal/handler/admin/search_truncate_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package admin + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑 +func truncateSearchByRune(search string, maxRunes int) string { + if runes := []rune(search); len(runes) > maxRunes { + return string(runes[:maxRunes]) + } + return search +} + +func TestTruncateSearchByRune(t *testing.T) { + tests := []struct { + name string + input string + maxRunes int + wantLen int // 期望的 rune 长度 + }{ + { + name: "纯中文超长", + input: string(make([]rune, 150)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "纯 ASCII 超长", + input: string(make([]byte, 150)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "空字符串", + input: "", + maxRunes: 100, + wantLen: 0, + }, + { + name: "恰好 100 个字符", + input: string(make([]rune, 100)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "不足 100 字符不截断", + input: "hello世界", + maxRunes: 100, + wantLen: 7, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := truncateSearchByRune(tc.input, tc.maxRunes) + require.Equal(t, tc.wantLen, len([]rune(result))) + }) + } +} + +func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) { + // 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8 + input := "" + for i := 0; i < 101; i++ { + input += "中" + } + result := truncateSearchByRune(input, 100) + + require.Equal(t, 100, len([]rune(result))) + // 验证截断结果是有效的 UTF-8(每个中文字符 3 字节) + require.Equal(t, 300, len(result)) +} + +func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) { + // 50 个 ASCII + 51 个中文 = 101 个 rune + input := "" + for i := 0; i < 50; i++ { + input += "a" + } + for i := 0; i < 51; i++ { + input += "中" + } + result := truncateSearchByRune(input, 100) + + runes := []rune(result) + require.Equal(t, 100, len(runes)) + // 前 50 个应该是 'a',后 50 个应该是 '中' + require.Equal(t, 'a', runes[0]) + require.Equal(t, 'a', runes[49]) + require.Equal(t, '中', runes[50]) + require.Equal(t, '中', runes[99]) +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 1e723ee5..e32c142f 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1,7 +1,13 @@ package admin import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" "log" + "net/http" + "regexp" "strings" "time" @@ -14,21 +20,38 @@ import ( "github.com/gin-gonic/gin" ) +// semverPattern 预编译 semver 格式校验正则 +var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`) + +// menuItemIDPattern validates custom menu item IDs: alphanumeric, hyphens, underscores only. +var menuItemIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + +// generateMenuItemID generates a short random hex ID for a custom menu item. +func generateMenuItemID() (string, error) { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate menu item ID: %w", err) + } + return hex.EncodeToString(b), nil +} + // SettingHandler 系统设置处理器 type SettingHandler struct { settingService *service.SettingService emailService *service.EmailService turnstileService *service.TurnstileService opsService *service.OpsService + soraS3Storage *service.SoraS3Storage } // NewSettingHandler 创建系统设置处理器 -func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler { +func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler { return &SettingHandler{ settingService: settingService, emailService: emailService, turnstileService: turnstileService, opsService: opsService, + soraS3Storage: soraS3Storage, } } @@ -43,6 +66,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { // Check if ops monitoring is enabled (respects config.ops.enabled) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) + defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions)) + for _, sub := range settings.DefaultSubscriptions { + defaultSubscriptions = append(defaultSubscriptions, dto.DefaultSubscriptionSetting{ + GroupID: sub.GroupID, + ValidityDays: sub.ValidityDays, + }) + } response.Success(c, dto.SystemSettings{ RegistrationEnabled: settings.RegistrationEnabled, @@ -76,8 +106,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + SoraClientEnabled: settings.SoraClientEnabled, + CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, + DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: settings.EnableModelFallback, FallbackModelAnthropic: settings.FallbackModelAnthropic, FallbackModelOpenAI: settings.FallbackModelOpenAI, @@ -89,6 +122,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled, OpsQueryModeDefault: settings.OpsQueryModeDefault, OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, + MinClaudeCodeVersion: settings.MinClaudeCodeVersion, }) } @@ -123,20 +157,23 @@ type UpdateSettingsRequest struct { LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` // OEM设置 - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` // 默认配置 - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -154,6 +191,8 @@ type UpdateSettingsRequest struct { OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"` OpsQueryModeDefault *string `json:"ops_query_mode_default"` OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"` + + MinClaudeCodeVersion string `json:"min_claude_code_version"` } // UpdateSettings 更新系统设置 @@ -181,6 +220,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { if req.SMTPPort <= 0 { req.SMTPPort = 587 } + req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) // Turnstile 参数验证 if req.TurnstileEnabled { @@ -276,6 +316,84 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // 自定义菜单项验证 + const ( + maxCustomMenuItems = 20 + maxMenuItemLabelLen = 50 + maxMenuItemURLLen = 2048 + maxMenuItemIconSVGLen = 10 * 1024 // 10KB + maxMenuItemIDLen = 32 + ) + + customMenuJSON := previousSettings.CustomMenuItems + if req.CustomMenuItems != nil { + items := *req.CustomMenuItems + if len(items) > maxCustomMenuItems { + response.BadRequest(c, "Too many custom menu items (max 20)") + return + } + for i, item := range items { + if strings.TrimSpace(item.Label) == "" { + response.BadRequest(c, "Custom menu item label is required") + return + } + if len(item.Label) > maxMenuItemLabelLen { + response.BadRequest(c, "Custom menu item label is too long (max 50 characters)") + return + } + if strings.TrimSpace(item.URL) == "" { + response.BadRequest(c, "Custom menu item URL is required") + return + } + if len(item.URL) > maxMenuItemURLLen { + response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)") + return + } + if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil { + response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL") + return + } + if item.Visibility != "user" && item.Visibility != "admin" { + response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'") + return + } + if len(item.IconSVG) > maxMenuItemIconSVGLen { + response.BadRequest(c, "Custom menu item icon SVG is too large (max 10KB)") + return + } + // Auto-generate ID if missing + if strings.TrimSpace(item.ID) == "" { + id, err := generateMenuItemID() + if err != nil { + response.Error(c, http.StatusInternalServerError, "Failed to generate menu item ID") + return + } + items[i].ID = id + } else if len(item.ID) > maxMenuItemIDLen { + response.BadRequest(c, "Custom menu item ID is too long (max 32 characters)") + return + } else if !menuItemIDPattern.MatchString(item.ID) { + response.BadRequest(c, "Custom menu item ID contains invalid characters (only a-z, A-Z, 0-9, - and _ are allowed)") + return + } + } + // ID uniqueness check + seen := make(map[string]struct{}, len(items)) + for _, item := range items { + if _, exists := seen[item.ID]; exists { + response.BadRequest(c, "Duplicate custom menu item ID: "+item.ID) + return + } + seen[item.ID] = struct{}{} + } + menuBytes, err := json.Marshal(items) + if err != nil { + response.BadRequest(c, "Failed to serialize custom menu items") + return + } + customMenuJSON = string(menuBytes) + } + // Ops metrics collector interval validation (seconds). if req.OpsMetricsIntervalSeconds != nil { v := *req.OpsMetricsIntervalSeconds @@ -287,6 +405,21 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } req.OpsMetricsIntervalSeconds = &v } + defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions)) + for _, sub := range req.DefaultSubscriptions { + defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{ + GroupID: sub.GroupID, + ValidityDays: sub.ValidityDays, + }) + } + + // 验证最低版本号格式(空字符串=禁用,或合法 semver) + if req.MinClaudeCodeVersion != "" { + if !semverPattern.MatchString(req.MinClaudeCodeVersion) { + response.Error(c, http.StatusBadRequest, "min_claude_code_version must be empty or a valid semver (e.g. 2.1.63)") + return + } + } settings := &service.SystemSettings{ RegistrationEnabled: req.RegistrationEnabled, @@ -319,8 +452,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: req.HideCcsImportButton, PurchaseSubscriptionEnabled: purchaseEnabled, PurchaseSubscriptionURL: purchaseURL, + SoraClientEnabled: req.SoraClientEnabled, + CustomMenuItems: customMenuJSON, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, + DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: req.EnableModelFallback, FallbackModelAnthropic: req.FallbackModelAnthropic, FallbackModelOpenAI: req.FallbackModelOpenAI, @@ -328,6 +464,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { FallbackModelAntigravity: req.FallbackModelAntigravity, EnableIdentityPatch: req.EnableIdentityPatch, IdentityPatchPrompt: req.IdentityPatchPrompt, + MinClaudeCodeVersion: req.MinClaudeCodeVersion, OpsMonitoringEnabled: func() bool { if req.OpsMonitoringEnabled != nil { return *req.OpsMonitoringEnabled @@ -367,6 +504,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions)) + for _, sub := range updatedSettings.DefaultSubscriptions { + updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{ + GroupID: sub.GroupID, + ValidityDays: sub.ValidityDays, + }) + } response.Success(c, dto.SystemSettings{ RegistrationEnabled: updatedSettings.RegistrationEnabled, @@ -400,8 +544,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: updatedSettings.HideCcsImportButton, PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, + SoraClientEnabled: updatedSettings.SoraClientEnabled, + CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, + DefaultSubscriptions: updatedDefaultSubscriptions, EnableModelFallback: updatedSettings.EnableModelFallback, FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, @@ -413,6 +560,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled, OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, + MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, }) } @@ -522,6 +670,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.DefaultBalance != after.DefaultBalance { changed = append(changed, "default_balance") } + if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) { + changed = append(changed, "default_subscriptions") + } if before.EnableModelFallback != after.EnableModelFallback { changed = append(changed, "enable_model_fallback") } @@ -555,9 +706,50 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds { changed = append(changed, "ops_metrics_interval_seconds") } + if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion { + changed = append(changed, "min_claude_code_version") + } + if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled { + changed = append(changed, "purchase_subscription_enabled") + } + if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL { + changed = append(changed, "purchase_subscription_url") + } + if before.CustomMenuItems != after.CustomMenuItems { + changed = append(changed, "custom_menu_items") + } return changed } +func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting { + if len(input) == 0 { + return nil + } + normalized := make([]dto.DefaultSubscriptionSetting, 0, len(input)) + for _, item := range input { + if item.GroupID <= 0 || item.ValidityDays <= 0 { + continue + } + if item.ValidityDays > service.MaxValidityDays { + item.ValidityDays = service.MaxValidityDays + } + normalized = append(normalized, item) + } + return normalized +} + +func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays { + return false + } + } + return true +} + // TestSMTPRequest 测试SMTP连接请求 type TestSMTPRequest struct { SMTPHost string `json:"smtp_host" binding:"required"` @@ -750,6 +942,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { }) } +func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings { + if settings == nil { + return dto.SoraS3Settings{} + } + return dto.SoraS3Settings{ + Enabled: settings.Enabled, + Endpoint: settings.Endpoint, + Region: settings.Region, + Bucket: settings.Bucket, + AccessKeyID: settings.AccessKeyID, + SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured, + Prefix: settings.Prefix, + ForcePathStyle: settings.ForcePathStyle, + CDNURL: settings.CDNURL, + DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes, + } +} + +func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile { + return dto.SoraS3Profile{ + ProfileID: profile.ProfileID, + Name: profile.Name, + IsActive: profile.IsActive, + Enabled: profile.Enabled, + Endpoint: profile.Endpoint, + Region: profile.Region, + Bucket: profile.Bucket, + AccessKeyID: profile.AccessKeyID, + SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured, + Prefix: profile.Prefix, + ForcePathStyle: profile.ForcePathStyle, + CDNURL: profile.CDNURL, + DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes, + UpdatedAt: profile.UpdatedAt, + } +} + +func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error { + if !enabled { + return nil + } + if strings.TrimSpace(endpoint) == "" { + return fmt.Errorf("S3 Endpoint is required when enabled") + } + if strings.TrimSpace(bucket) == "" { + return fmt.Errorf("S3 Bucket is required when enabled") + } + if strings.TrimSpace(accessKeyID) == "" { + return fmt.Errorf("S3 Access Key ID is required when enabled") + } + if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret { + return nil + } + return fmt.Errorf("S3 Secret Access Key is required when enabled") +} + +func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == profileID { + return &items[idx] + } + } + return nil +} + +// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口) +// GET /api/v1/admin/settings/sora-s3 +func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) { + settings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3SettingsDTO(settings)) +} + +// ListSoraS3Profiles 获取 Sora S3 多配置 +// GET /api/v1/admin/settings/sora-s3/profiles +func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) { + result, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + items := make([]dto.SoraS3Profile, 0, len(result.Items)) + for idx := range result.Items { + items = append(items, toSoraS3ProfileDTO(result.Items[idx])) + } + response.Success(c, dto.ListSoraS3ProfilesResponse{ + ActiveProfileID: result.ActiveProfileID, + Items: items, + }) +} + +// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口) +type UpdateSoraS3SettingsRequest struct { + ProfileID string `json:"profile_id"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +type CreateSoraS3ProfileRequest struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + SetActive bool `json:"set_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +type UpdateSoraS3ProfileRequest struct { + Name string `json:"name"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// CreateSoraS3Profile 创建 Sora S3 配置 +// POST /api/v1/admin/settings/sora-s3/profiles +func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) { + var req CreateSoraS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 0 + } + if strings.TrimSpace(req.Name) == "" { + response.BadRequest(c, "Name is required") + return + } + if strings.TrimSpace(req.ProfileID) == "" { + response.BadRequest(c, "Profile ID is required") + return + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil { + response.BadRequest(c, err.Error()) + return + } + + created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{ + ProfileID: req.ProfileID, + Name: req.Name, + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + }, req.SetActive) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, toSoraS3ProfileDTO(*created)) +} + +// UpdateSoraS3Profile 更新 Sora S3 配置 +// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id +func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile ID is required") + return + } + + var req UpdateSoraS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 0 + } + if strings.TrimSpace(req.Name) == "" { + response.BadRequest(c, "Name is required") + return + } + + existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + existing := findSoraS3ProfileByID(existingList.Items, profileID) + if existing == nil { + response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound) + return + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { + response.BadRequest(c, err.Error()) + return + } + + updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{ + Name: req.Name, + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + }) + if updateErr != nil { + response.ErrorFrom(c, updateErr) + return + } + + response.Success(c, toSoraS3ProfileDTO(*updated)) +} + +// DeleteSoraS3Profile 删除 Sora S3 配置 +// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id +func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile ID is required") + return + } + if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +// SetActiveSoraS3Profile 切换激活 Sora S3 配置 +// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate +func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile ID is required") + return + } + active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3ProfileDTO(*active)) +} + +// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口) +// PUT /api/v1/admin/settings/sora-s3 +func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) { + var req UpdateSoraS3SettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + if req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 0 + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { + response.BadRequest(c, err.Error()) + return + } + + settings := &service.SoraS3Settings{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + } + if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil { + response.ErrorFrom(c, err) + return + } + + updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3SettingsDTO(updatedSettings)) +} + +// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket) +// POST /api/v1/admin/settings/sora-s3/test +func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) { + if h.soraS3Storage == nil { + response.Error(c, 500, "S3 存储服务未初始化") + return + } + + var req UpdateSoraS3SettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if !req.Enabled { + response.BadRequest(c, "S3 未启用,无法测试连接") + return + } + + if req.SecretAccessKey == "" { + if req.ProfileID != "" { + profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err == nil { + profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID) + if profile != nil { + req.SecretAccessKey = profile.SecretAccessKey + } + } + } + if req.SecretAccessKey == "" { + existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err == nil { + req.SecretAccessKey = existing.SecretAccessKey + } + } + } + + testCfg := &service.SoraS3Settings{ + Enabled: true, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + } + if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil { + response.Error(c, 400, "S3 连接测试失败: "+err.Error()) + return + } + response.Success(c, gin.H{"message": "S3 连接成功"}) +} + // UpdateStreamTimeoutSettingsRequest 更新流超时配置请求 type UpdateStreamTimeoutSettingsRequest struct { Enabled bool `json:"enabled"` diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index 51995ab1..e5b6db13 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "strconv" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { return } - subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days) - if err != nil { - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + SubscriptionID int64 `json:"subscription_id"` + Body AdjustSubscriptionRequest `json:"body"` + }{ + SubscriptionID: subscriptionID, + Body: req, } - - response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription)) + executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days) + if execErr != nil { + return nil, execErr + } + return dto.UserSubscriptionFromServiceAdmin(subscription), nil + }) } // Revoke handles revoking a subscription diff --git a/backend/internal/handler/admin/system_handler.go b/backend/internal/handler/admin/system_handler.go index 22442a4e..a061cd31 100644 --- a/backend/internal/handler/admin/system_handler.go +++ b/backend/internal/handler/admin/system_handler.go @@ -1,11 +1,15 @@ package admin import ( + "context" "net/http" + "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/sysutil" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -14,12 +18,14 @@ import ( // SystemHandler handles system-related operations type SystemHandler struct { updateSvc *service.UpdateService + lockSvc *service.SystemOperationLockService } // NewSystemHandler creates a new SystemHandler -func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler { +func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler { return &SystemHandler{ updateSvc: updateSvc, + lockSvc: lockSvc, } } @@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) { // PerformUpdate downloads and applies the update // POST /api/v1/admin/system/update func (h *SystemHandler) PerformUpdate(c *gin.Context) { - if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil { - response.Error(c, http.StatusInternalServerError, err.Error()) - return - } - response.Success(c, gin.H{ - "message": "Update completed. Please restart the service.", - "need_restart": true, + operationID := buildSystemOperationID(c, "update") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + var releaseReason string + succeeded := false + defer func() { + release(releaseReason, succeeded) + }() + + if err := h.updateSvc.PerformUpdate(ctx); err != nil { + releaseReason = "SYSTEM_UPDATE_FAILED" + return nil, err + } + succeeded = true + + return gin.H{ + "message": "Update completed. Please restart the service.", + "need_restart": true, + "operation_id": lock.OperationID(), + }, nil }) } // Rollback restores the previous version // POST /api/v1/admin/system/rollback func (h *SystemHandler) Rollback(c *gin.Context) { - if err := h.updateSvc.Rollback(); err != nil { - response.Error(c, http.StatusInternalServerError, err.Error()) - return - } - response.Success(c, gin.H{ - "message": "Rollback completed. Please restart the service.", - "need_restart": true, + operationID := buildSystemOperationID(c, "rollback") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + var releaseReason string + succeeded := false + defer func() { + release(releaseReason, succeeded) + }() + + if err := h.updateSvc.Rollback(); err != nil { + releaseReason = "SYSTEM_ROLLBACK_FAILED" + return nil, err + } + succeeded = true + + return gin.H{ + "message": "Rollback completed. Please restart the service.", + "need_restart": true, + "operation_id": lock.OperationID(), + }, nil }) } // RestartService restarts the systemd service // POST /api/v1/admin/system/restart func (h *SystemHandler) RestartService(c *gin.Context) { - // Schedule service restart in background after sending response - // This ensures the client receives the success response before the service restarts - go func() { - // Wait a moment to ensure the response is sent - time.Sleep(500 * time.Millisecond) - sysutil.RestartServiceAsync() - }() + operationID := buildSystemOperationID(c, "restart") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + succeeded := false + defer func() { + release("", succeeded) + }() - response.Success(c, gin.H{ - "message": "Service restart initiated", + // Schedule service restart in background after sending response + // This ensures the client receives the success response before the service restarts + go func() { + // Wait a moment to ensure the response is sent + time.Sleep(500 * time.Millisecond) + sysutil.RestartServiceAsync() + }() + succeeded = true + return gin.H{ + "message": "Service restart initiated", + "operation_id": lock.OperationID(), + }, nil }) } + +func (h *SystemHandler) acquireSystemLock( + ctx context.Context, + operationID string, +) (*service.SystemOperationLock, func(string, bool), error) { + if h.lockSvc == nil { + return nil, nil, service.ErrIdempotencyStoreUnavail + } + lock, err := h.lockSvc.Acquire(ctx, operationID) + if err != nil { + return nil, nil, err + } + release := func(reason string, succeeded bool) { + releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason) + } + return lock, release, nil +} + +func buildSystemOperationID(c *gin.Context, operation string) string { + key := strings.TrimSpace(c.GetHeader("Idempotency-Key")) + if key == "" { + return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36) + } + actorScope := "admin:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key + hash := service.HashIdempotencyKey(seed) + if len(hash) > 24 { + hash = hash[:24] + } + return "sysop-" + hash +} diff --git a/backend/internal/handler/admin/usage_cleanup_handler_test.go b/backend/internal/handler/admin/usage_cleanup_handler_test.go index ed1c7cc2..6152d5e9 100644 --- a/backend/internal/handler/admin/usage_cleanup_handler_test.go +++ b/backend/internal/handler/admin/usage_cleanup_handler_test.go @@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) { require.Equal(t, http.StatusBadRequest, recorder.Code) } +func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 88) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "request_type": "invalid", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 99) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "request_type": "ws_v2", + "stream": false, + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.created, 1) + created := repo.created[0] + require.NotNil(t, created.Filters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType) + require.Nil(t, created.Filters.Stream) +} + +func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 99) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "stream": true, + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.created, 1) + created := repo.created[0] + require.Nil(t, created.Filters.RequestType) + require.NotNil(t, created.Filters.Stream) + require.True(t, *created.Filters.Stream) +} + func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) { repo := &cleanupRepoStub{} cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index 3f3238dd..d0bba773 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -1,13 +1,14 @@ package admin import ( - "log" + "context" "net/http" "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" @@ -50,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct { AccountID *int64 `json:"account_id"` GroupID *int64 `json:"group_id"` Model *string `json:"model"` + RequestType *string `json:"request_type"` Stream *bool `json:"stream"` BillingType *int8 `json:"billing_type"` Timezone string `json:"timezone"` @@ -100,8 +102,17 @@ func (h *UsageHandler) List(c *gin.Context) { model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -151,6 +162,7 @@ func (h *UsageHandler) List(c *gin.Context) { AccountID: accountID, GroupID: groupID, Model: model, + RequestType: requestType, Stream: stream, BillingType: billingType, StartTime: startTime, @@ -213,8 +225,17 @@ func (h *UsageHandler) Stats(c *gin.Context) { model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -277,6 +298,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { AccountID: accountID, GroupID: groupID, Model: model, + RequestType: requestType, Stream: stream, BillingType: billingType, StartTime: &startTime, @@ -378,11 +400,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) { operator = subject.UserID } page, pageSize := response.ParsePagination(c) - log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize) params := pagination.PaginationParams{Page: page, PageSize: pageSize} tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params) if err != nil { - log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err) response.ErrorFrom(c, err) return } @@ -390,7 +412,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) { for i := range tasks { out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i])) } - log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize) response.Paginated(c, out, result.Total, page, pageSize) } @@ -431,6 +453,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { } endTime = endTime.Add(24*time.Hour - time.Nanosecond) + var requestType *int16 + stream := req.Stream + if req.RequestType != nil { + parsed, err := service.ParseUsageRequestType(*req.RequestType) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + stream = nil + } + filters := service.UsageCleanupFilters{ StartTime: startTime, EndTime: endTime, @@ -439,7 +474,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { AccountID: req.AccountID, GroupID: req.GroupID, Model: req.Model, - Stream: req.Stream, + RequestType: requestType, + Stream: stream, BillingType: req.BillingType, } @@ -463,38 +499,50 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { if filters.Model != nil { model = *filters.Model } - var stream any + var streamValue any if filters.Stream != nil { - stream = *filters.Stream + streamValue = *filters.Stream + } + var requestTypeName any + if filters.RequestType != nil { + requestTypeName = service.RequestTypeFromInt16(*filters.RequestType).String() } var billingType any if filters.BillingType != nil { billingType = *filters.BillingType } - log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", - subject.UserID, - filters.StartTime.Format(time.RFC3339), - filters.EndTime.Format(time.RFC3339), - userID, - apiKeyID, - accountID, - groupID, - model, - stream, - billingType, - req.Timezone, - ) - - task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID) - if err != nil { - log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err) - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + OperatorID int64 `json:"operator_id"` + Body CreateUsageCleanupTaskRequest `json:"body"` + }{ + OperatorID: subject.UserID, + Body: req, } + executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q", + subject.UserID, + filters.StartTime.Format(time.RFC3339), + filters.EndTime.Format(time.RFC3339), + userID, + apiKeyID, + accountID, + groupID, + model, + requestTypeName, + streamValue, + billingType, + req.Timezone, + ) - log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) - response.Success(c, dto.UsageCleanupTaskFromService(task)) + task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID) + if err != nil { + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err) + return nil, err + } + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) + return dto.UsageCleanupTaskFromService(task), nil + }) } // CancelCleanupTask handles canceling a usage cleanup task @@ -515,12 +563,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) { response.BadRequest(c, "Invalid task id") return } - log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID) if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil { - log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err) response.ErrorFrom(c, err) return } - log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID) response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled}) } diff --git a/backend/internal/handler/admin/usage_handler_request_type_test.go b/backend/internal/handler/admin/usage_handler_request_type_test.go new file mode 100644 index 00000000..21add574 --- /dev/null +++ b/backend/internal/handler/admin/usage_handler_request_type_test.go @@ -0,0 +1,117 @@ +package admin + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type adminUsageRepoCapture struct { + service.UsageLogRepository + listFilters usagestats.UsageLogFilters + statsFilters usagestats.UsageLogFilters +} + +func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + s.listFilters = filters + return []service.UsageLog{}, &pagination.PaginationResult{ + Total: 0, + Page: params.Page, + PageSize: params.PageSize, + Pages: 0, + }, nil +} + +func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { + s.statsFilters = filters + return &usagestats.UsageStats{}, nil +} + +func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + usageSvc := service.NewUsageService(repo, nil, nil, nil) + handler := NewUsageHandler(usageSvc, nil, nil, nil) + router := gin.New() + router.GET("/admin/usage", handler.List) + router.GET("/admin/usage/stats", handler.Stats) + return router +} + +func TestAdminUsageListRequestTypePriority(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.listFilters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType) + require.Nil(t, repo.listFilters.Stream) +} + +func TestAdminUsageListInvalidRequestType(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageListInvalidStream(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageStatsRequestTypePriority(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.statsFilters.RequestType) + require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType) + require.Nil(t, repo.statsFilters.Stream) +} + +func TestAdminUsageStatsInvalidRequestType(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageStatsInvalidStream(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 1c772e7d..f85c060e 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "strconv" "strings" @@ -11,27 +12,36 @@ import ( "github.com/gin-gonic/gin" ) +// UserWithConcurrency wraps AdminUser with current concurrency info +type UserWithConcurrency struct { + dto.AdminUser + CurrentConcurrency int `json:"current_concurrency"` +} + // UserHandler handles admin user management type UserHandler struct { - adminService service.AdminService + adminService service.AdminService + concurrencyService *service.ConcurrencyService } // NewUserHandler creates a new admin user handler -func NewUserHandler(adminService service.AdminService) *UserHandler { +func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler { return &UserHandler{ - adminService: adminService, + adminService: adminService, + concurrencyService: concurrencyService, } } // CreateUserRequest represents admin create user request type CreateUserRequest struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required,min=6"` - Username string `json:"username"` - Notes string `json:"notes"` - Balance float64 `json:"balance"` - Concurrency int `json:"concurrency"` - AllowedGroups []int64 `json:"allowed_groups"` + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required,min=6"` + Username string `json:"username"` + Notes string `json:"notes"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` + AllowedGroups []int64 `json:"allowed_groups"` + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` } // UpdateUserRequest represents admin update user request @@ -47,7 +57,8 @@ type UpdateUserRequest struct { AllowedGroups *[]int64 `json:"allowed_groups"` // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 `json:"group_rates"` + GroupRates map[int64]*float64 `json:"group_rates"` + SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` } // UpdateBalanceRequest represents balance update request @@ -70,8 +81,8 @@ func (h *UserHandler) List(c *gin.Context) { search := c.Query("search") // 标准化和验证 search 参数 search = strings.TrimSpace(search) - if len(search) > 100 { - search = search[:100] + if runes := []rune(search); len(runes) > 100 { + search = string(runes[:100]) } filters := service.UserListFilters{ @@ -87,10 +98,30 @@ func (h *UserHandler) List(c *gin.Context) { return } - out := make([]dto.AdminUser, 0, len(users)) - for i := range users { - out = append(out, *dto.UserFromServiceAdmin(&users[i])) + // Batch get current concurrency (nil map if unavailable) + var loadInfo map[int64]*service.UserLoadInfo + if len(users) > 0 && h.concurrencyService != nil { + usersConcurrency := make([]service.UserWithConcurrency, len(users)) + for i := range users { + usersConcurrency[i] = service.UserWithConcurrency{ + ID: users[i].ID, + MaxConcurrency: users[i].Concurrency, + } + } + loadInfo, _ = h.concurrencyService.GetUsersLoadBatch(c.Request.Context(), usersConcurrency) } + + // Build response with concurrency info + out := make([]UserWithConcurrency, len(users)) + for i := range users { + out[i] = UserWithConcurrency{ + AdminUser: *dto.UserFromServiceAdmin(&users[i]), + } + if info := loadInfo[users[i].ID]; info != nil { + out[i].CurrentConcurrency = info.CurrentConcurrency + } + } + response.Paginated(c, out, total, page, pageSize) } @@ -145,13 +176,14 @@ func (h *UserHandler) Create(c *gin.Context) { } user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - AllowedGroups: req.AllowedGroups, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + Balance: req.Balance, + Concurrency: req.Concurrency, + AllowedGroups: req.AllowedGroups, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, }) if err != nil { response.ErrorFrom(c, err) @@ -178,15 +210,16 @@ func (h *UserHandler) Update(c *gin.Context) { // 使用指针类型直接传递,nil 表示未提供该字段 user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - Status: req.Status, - AllowedGroups: req.AllowedGroups, - GroupRates: req.GroupRates, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + Balance: req.Balance, + Concurrency: req.Concurrency, + Status: req.Status, + AllowedGroups: req.AllowedGroups, + GroupRates: req.GroupRates, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, }) if err != nil { response.ErrorFrom(c, err) @@ -229,13 +262,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) { return } - user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes) - if err != nil { - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + UserID int64 `json:"user_id"` + Body UpdateBalanceRequest `json:"body"` + }{ + UserID: userID, + Body: req, } - - response.Success(c, dto.UserFromServiceAdmin(user)) + executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes) + if execErr != nil { + return nil, execErr + } + return dto.UserFromServiceAdmin(user), nil + }) } // GetUserAPIKeys handles getting user's API keys diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index f1a18ad2..61762744 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -2,6 +2,7 @@ package handler import ( + "context" "strconv" "time" @@ -130,13 +131,14 @@ func (h *APIKeyHandler) Create(c *gin.Context) { if req.Quota != nil { svcReq.Quota = *req.Quota } - key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, dto.APIKeyFromService(key)) + executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq) + if err != nil { + return nil, err + } + return dto.APIKeyFromService(key), nil + }) } // Update handles updating an API key diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 34ed63bc..1ffa9d71 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -2,6 +2,7 @@ package handler import ( "log/slog" + "strings" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -112,12 +113,10 @@ func (h *AuthHandler) Register(c *gin.Context) { return } - // Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过) - if req.VerifyCode == "" { - if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { - response.ErrorFrom(c, err) - return - } + // Turnstile 验证(邮箱验证码注册场景避免重复校验一次性 token) + if err := h.authService.VerifyTurnstileForRegister(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c), req.VerifyCode); err != nil { + response.ErrorFrom(c, err) + return } _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) @@ -448,17 +447,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { return } - // Build frontend base URL from request - scheme := "https" - if c.Request.TLS == nil { - // Check X-Forwarded-Proto header (common in reverse proxy setups) - if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" { - scheme = proto - } else { - scheme = "http" - } + frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL) + if frontendBaseURL == "" { + slog.Error("server.frontend_url not configured; cannot build password reset link") + response.InternalError(c, "Password reset is not configured") + return } - frontendBaseURL := scheme + "://" + c.Request.Host // Request password reset (async) // Note: This returns success even if email doesn't exist (to prevent enumeration) diff --git a/backend/internal/handler/dto/api_key_mapper_last_used_test.go b/backend/internal/handler/dto/api_key_mapper_last_used_test.go new file mode 100644 index 00000000..99644ced --- /dev/null +++ b/backend/internal/handler/dto/api_key_mapper_last_used_test.go @@ -0,0 +1,40 @@ +package dto + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestAPIKeyFromService_MapsLastUsedAt(t *testing.T) { + lastUsed := time.Now().UTC().Truncate(time.Second) + src := &service.APIKey{ + ID: 1, + UserID: 2, + Key: "sk-map-last-used", + Name: "Mapper", + Status: service.StatusActive, + LastUsedAt: &lastUsed, + } + + out := APIKeyFromService(src) + require.NotNil(t, out) + require.NotNil(t, out.LastUsedAt) + require.WithinDuration(t, lastUsed, *out.LastUsedAt, time.Second) +} + +func TestAPIKeyFromService_MapsNilLastUsedAt(t *testing.T) { + src := &service.APIKey{ + ID: 1, + UserID: 2, + Key: "sk-map-last-used-nil", + Name: "MapperNil", + Status: service.StatusActive, + } + + out := APIKeyFromService(src) + require.NotNil(t, out) + require.Nil(t, out.LastUsedAt) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index d14ab1d1..1c34f537 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -2,6 +2,7 @@ package dto import ( + "strconv" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -58,9 +59,11 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { return nil } return &AdminUser{ - User: *base, - Notes: u.Notes, - GroupRates: u.GroupRates, + User: *base, + Notes: u.Notes, + GroupRates: u.GroupRates, + SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, + SoraStorageUsedBytes: u.SoraStorageUsedBytes, } } @@ -77,6 +80,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey { Status: k.Status, IPWhitelist: k.IPWhitelist, IPBlacklist: k.IPBlacklist, + LastUsedAt: k.LastUsedAt, Quota: k.Quota, QuotaUsed: k.QuotaUsed, ExpiresAt: k.ExpiresAt, @@ -115,6 +119,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { MCPXMLInject: g.MCPXMLInject, SupportedModelScopes: g.SupportedModelScopes, AccountCount: g.AccountCount, + SortOrder: g.SortOrder, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) @@ -128,24 +133,28 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { func groupFromServiceBase(g *service.Group) Group { return Group{ - ID: g.ID, - Name: g.Name, - Description: g.Description, - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUSD, - WeeklyLimitUSD: g.WeeklyLimitUSD, - MonthlyLimitUSD: g.MonthlyLimitUSD, - ImagePrice1K: g.ImagePrice1K, - ImagePrice2K: g.ImagePrice2K, - ImagePrice4K: g.ImagePrice4K, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - // 无效请求兜底分组 + ID: g.ID, + Name: g.Name, + Description: g.Description, + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUSD, + WeeklyLimitUSD: g.WeeklyLimitUSD, + MonthlyLimitUSD: g.MonthlyLimitUSD, + ImagePrice1K: g.ImagePrice1K, + ImagePrice2K: g.ImagePrice2K, + ImagePrice4K: g.ImagePrice4K, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } @@ -200,6 +209,17 @@ func AccountFromServiceShallow(a *service.Account) *Account { if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 { out.SessionIdleTimeoutMin = &idleTimeout } + if rpm := a.GetBaseRPM(); rpm > 0 { + out.BaseRPM = &rpm + strategy := a.GetRPMStrategy() + out.RPMStrategy = &strategy + buffer := a.GetRPMStickyBuffer() + out.RPMStickyBuffer = &buffer + } + // 用户消息队列模式 + if mode := a.GetUserMsgQueueMode(); mode != "" { + out.UserMsgQueueMode = &mode + } // TLS指纹伪装开关 if a.IsTLSFingerprintEnabled() { enabled := true @@ -210,6 +230,13 @@ func AccountFromServiceShallow(a *service.Account) *Account { enabled := true out.EnableSessionIDMasking = &enabled } + // 缓存 TTL 强制替换 + if a.IsCacheTTLOverrideEnabled() { + enabled := true + out.CacheTTLOverrideEnabled = &enabled + target := a.GetCacheTTLOverrideTarget() + out.CacheTTLOverrideTarget = &target + } } return out @@ -270,7 +297,6 @@ func ProxyFromService(p *service.Proxy) *Proxy { Host: p.Host, Port: p.Port, Username: p.Username, - Password: p.Password, Status: p.Status, CreatedAt: p.CreatedAt, UpdatedAt: p.UpdatedAt, @@ -292,6 +318,56 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi CountryCode: p.CountryCode, Region: p.Region, City: p.City, + QualityStatus: p.QualityStatus, + QualityScore: p.QualityScore, + QualityGrade: p.QualityGrade, + QualitySummary: p.QualitySummary, + QualityChecked: p.QualityChecked, + } +} + +// ProxyFromServiceAdmin converts a service Proxy to AdminProxy DTO for admin users. +// It includes the password field - user-facing endpoints must not use this. +func ProxyFromServiceAdmin(p *service.Proxy) *AdminProxy { + if p == nil { + return nil + } + base := ProxyFromService(p) + if base == nil { + return nil + } + return &AdminProxy{ + Proxy: *base, + Password: p.Password, + } +} + +// ProxyWithAccountCountFromServiceAdmin converts a service ProxyWithAccountCount to AdminProxyWithAccountCount DTO. +// It includes the password field - user-facing endpoints must not use this. +func ProxyWithAccountCountFromServiceAdmin(p *service.ProxyWithAccountCount) *AdminProxyWithAccountCount { + if p == nil { + return nil + } + admin := ProxyFromServiceAdmin(&p.Proxy) + if admin == nil { + return nil + } + return &AdminProxyWithAccountCount{ + AdminProxy: *admin, + AccountCount: p.AccountCount, + LatencyMs: p.LatencyMs, + LatencyStatus: p.LatencyStatus, + LatencyMessage: p.LatencyMessage, + IPAddress: p.IPAddress, + Country: p.Country, + CountryCode: p.CountryCode, + Region: p.Region, + City: p.City, + QualityStatus: p.QualityStatus, + QualityScore: p.QualityScore, + QualityGrade: p.QualityGrade, + QualitySummary: p.QualitySummary, + QualityChecked: p.QualityChecked, } } @@ -367,6 +443,8 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary { func usageLogFromServiceUser(l *service.UsageLog) UsageLog { // 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。 + requestType := l.EffectiveRequestType() + stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode) return UsageLog{ ID: l.ID, UserID: l.UserID, @@ -391,12 +469,16 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ActualCost: l.ActualCost, RateMultiplier: l.RateMultiplier, BillingType: l.BillingType, - Stream: l.Stream, + RequestType: requestType.String(), + Stream: stream, + OpenAIWSMode: openAIWSMode, DurationMs: l.DurationMs, FirstTokenMs: l.FirstTokenMs, ImageCount: l.ImageCount, ImageSize: l.ImageSize, + MediaType: l.MediaType, UserAgent: l.UserAgent, + CacheTTLOverridden: l.CacheTTLOverridden, CreatedAt: l.CreatedAt, User: UserFromServiceShallow(l.User), APIKey: APIKeyFromService(l.APIKey), @@ -444,6 +526,7 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa AccountID: task.Filters.AccountID, GroupID: task.Filters.GroupID, Model: task.Filters.Model, + RequestType: requestTypeStringPtr(task.Filters.RequestType), Stream: task.Filters.Stream, BillingType: task.Filters.BillingType, }, @@ -459,6 +542,14 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa } } +func requestTypeStringPtr(requestType *int16) *string { + if requestType == nil { + return nil + } + value := service.RequestTypeFromInt16(*requestType).String() + return &value +} + func SettingFromService(s *service.Setting) *Setting { if s == nil { return nil @@ -523,11 +614,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult for i := range r.Subscriptions { subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i])) } + statuses := make(map[string]string, len(r.Statuses)) + for userID, status := range r.Statuses { + statuses[strconv.FormatInt(userID, 10)] = status + } return &BulkAssignResult{ SuccessCount: r.SuccessCount, + CreatedCount: r.CreatedCount, + ReusedCount: r.ReusedCount, FailedCount: r.FailedCount, Subscriptions: subs, Errors: r.Errors, + Statuses: statuses, } } diff --git a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go new file mode 100644 index 00000000..d716bdc4 --- /dev/null +++ b/backend/internal/handler/dto/mappers_usage_test.go @@ -0,0 +1,73 @@ +package dto + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUsageLogFromService_IncludesOpenAIWSMode(t *testing.T) { + t.Parallel() + + wsLog := &service.UsageLog{ + RequestID: "req_1", + Model: "gpt-5.3-codex", + OpenAIWSMode: true, + } + httpLog := &service.UsageLog{ + RequestID: "resp_1", + Model: "gpt-5.3-codex", + OpenAIWSMode: false, + } + + require.True(t, UsageLogFromService(wsLog).OpenAIWSMode) + require.False(t, UsageLogFromService(httpLog).OpenAIWSMode) + require.True(t, UsageLogFromServiceAdmin(wsLog).OpenAIWSMode) + require.False(t, UsageLogFromServiceAdmin(httpLog).OpenAIWSMode) +} + +func TestUsageLogFromService_PrefersRequestTypeForLegacyFields(t *testing.T) { + t.Parallel() + + log := &service.UsageLog{ + RequestID: "req_2", + Model: "gpt-5.3-codex", + RequestType: service.RequestTypeWSV2, + Stream: false, + OpenAIWSMode: false, + } + + userDTO := UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + require.Equal(t, "ws_v2", userDTO.RequestType) + require.True(t, userDTO.Stream) + require.True(t, userDTO.OpenAIWSMode) + require.Equal(t, "ws_v2", adminDTO.RequestType) + require.True(t, adminDTO.Stream) + require.True(t, adminDTO.OpenAIWSMode) +} + +func TestUsageCleanupTaskFromService_RequestTypeMapping(t *testing.T) { + t.Parallel() + + requestType := int16(service.RequestTypeStream) + task := &service.UsageCleanupTask{ + ID: 1, + Status: service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{ + RequestType: &requestType, + }, + } + + dtoTask := UsageCleanupTaskFromService(task) + require.NotNil(t, dtoTask) + require.NotNil(t, dtoTask.Filters.RequestType) + require.Equal(t, "stream", *dtoTask.Filters.RequestType) +} + +func TestRequestTypeStringPtrNil(t *testing.T) { + t.Parallel() + require.Nil(t, requestTypeStringPtr(nil)) +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index be94bc16..beb03e67 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -1,5 +1,20 @@ package dto +import ( + "encoding/json" + "strings" +) + +// CustomMenuItem represents a user-configured custom menu entry. +type CustomMenuItem struct { + ID string `json:"id"` + Label string `json:"label"` + IconSVG string `json:"icon_svg"` + URL string `json:"url"` + Visibility string `json:"visibility"` // "user" or "admin" + SortOrder int `json:"sort_order"` +} + // SystemSettings represents the admin settings API response payload. type SystemSettings struct { RegistrationEnabled bool `json:"registration_enabled"` @@ -27,19 +42,22 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -57,29 +75,76 @@ type SystemSettings struct { OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"` OpsQueryModeDefault string `json:"ops_query_mode_default"` OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"` + + MinClaudeCodeVersion string `json:"min_claude_code_version"` +} + +type DefaultSubscriptionSetting struct { + GroupID int64 `json:"group_id"` + ValidityDays int `json:"validity_days"` } type PublicSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - Version string `json:"version"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + SoraClientEnabled bool `json:"sora_client_enabled"` + Version string `json:"version"` +} + +// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段) +type SoraS3Settings struct { + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段) +type SoraS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// ListSoraS3ProfilesResponse Sora S3 配置列表响应 +type ListSoraS3ProfilesResponse struct { + ActiveProfileID string `json:"active_profile_id"` + Items []SoraS3Profile `json:"items"` } // StreamTimeoutSettings 流超时处理配置 DTO @@ -90,3 +155,29 @@ type StreamTimeoutSettings struct { ThresholdCount int `json:"threshold_count"` ThresholdWindowMinutes int `json:"threshold_window_minutes"` } + +// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. +// Returns empty slice on empty/invalid input. +func ParseCustomMenuItems(raw string) []CustomMenuItem { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return []CustomMenuItem{} + } + var items []CustomMenuItem + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []CustomMenuItem{} + } + return items +} + +// ParseUserVisibleMenuItems parses custom menu items and filters out admin-only entries. +func ParseUserVisibleMenuItems(raw string) []CustomMenuItem { + items := ParseCustomMenuItems(raw) + filtered := make([]CustomMenuItem, 0, len(items)) + for _, item := range items { + if item.Visibility != "admin" { + filtered = append(filtered, item) + } + } + return filtered +} diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 71bb1ed4..e9235797 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -2,11 +2,6 @@ package dto import "time" -type ScopeRateLimitInfo struct { - ResetAt time.Time `json:"reset_at"` - RemainingSec int64 `json:"remaining_sec"` -} - type User struct { ID int64 `json:"id"` Email string `json:"email"` @@ -31,7 +26,9 @@ type AdminUser struct { Notes string `json:"notes"` // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier - GroupRates map[int64]float64 `json:"group_rates,omitempty"` + GroupRates map[int64]float64 `json:"group_rates,omitempty"` + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"` } type APIKey struct { @@ -43,6 +40,7 @@ type APIKey struct { Status string `json:"status"` IPWhitelist []string `json:"ip_whitelist"` IPBlacklist []string `json:"ip_blacklist"` + LastUsedAt *time.Time `json:"last_used_at"` Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires) @@ -72,12 +70,21 @@ type Group struct { ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + // Sora 按次计费配置 + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` + // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` // 无效请求兜底分组 FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` + // Sora 存储配额 + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -98,6 +105,9 @@ type AdminGroup struct { SupportedModelScopes []string `json:"supported_model_scopes"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountCount int64 `json:"account_count,omitempty"` + + // 分组排序 + SortOrder int `json:"sort_order"` } type Account struct { @@ -126,9 +136,6 @@ type Account struct { RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` OverloadUntil *time.Time `json:"overload_until"` - // Antigravity scope 级限流状态(从 extra 提取) - ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"` - TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"` TempUnschedulableReason string `json:"temp_unschedulable_reason"` @@ -146,6 +153,13 @@ type Account struct { MaxSessions *int `json:"max_sessions,omitempty"` SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"` + // RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效) + // 从 extra 字段提取,方便前端显示和编辑 + BaseRPM *int `json:"base_rpm,omitempty"` + RPMStrategy *string `json:"rpm_strategy,omitempty"` + RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"` + UserMsgQueueMode *string `json:"user_msg_queue_mode,omitempty"` + // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效) // 从 extra 字段提取,方便前端显示和编辑 EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"` @@ -155,6 +169,11 @@ type Account struct { // 从 extra 字段提取,方便前端显示和编辑 EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"` + // 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效) + // 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费 + CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"` + CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` @@ -196,6 +215,37 @@ type ProxyWithAccountCount struct { CountryCode string `json:"country_code,omitempty"` Region string `json:"region,omitempty"` City string `json:"city,omitempty"` + QualityStatus string `json:"quality_status,omitempty"` + QualityScore *int `json:"quality_score,omitempty"` + QualityGrade string `json:"quality_grade,omitempty"` + QualitySummary string `json:"quality_summary,omitempty"` + QualityChecked *int64 `json:"quality_checked,omitempty"` +} + +// AdminProxy 是管理员接口使用的 proxy DTO(包含密码等敏感字段)。 +// 注意:普通接口不得使用此 DTO。 +type AdminProxy struct { + Proxy + Password string `json:"password,omitempty"` +} + +// AdminProxyWithAccountCount 是管理员接口使用的带账号统计的 proxy DTO。 +type AdminProxyWithAccountCount struct { + AdminProxy + AccountCount int64 `json:"account_count"` + LatencyMs *int64 `json:"latency_ms,omitempty"` + LatencyStatus string `json:"latency_status,omitempty"` + LatencyMessage string `json:"latency_message,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + Region string `json:"region,omitempty"` + City string `json:"city,omitempty"` + QualityStatus string `json:"quality_status,omitempty"` + QualityScore *int `json:"quality_score,omitempty"` + QualityGrade string `json:"quality_grade,omitempty"` + QualitySummary string `json:"quality_summary,omitempty"` + QualityChecked *int64 `json:"quality_checked,omitempty"` } type ProxyAccountSummary struct { @@ -266,18 +316,24 @@ type UsageLog struct { ActualCost float64 `json:"actual_cost"` RateMultiplier float64 `json:"rate_multiplier"` - BillingType int8 `json:"billing_type"` - Stream bool `json:"stream"` - DurationMs *int `json:"duration_ms"` - FirstTokenMs *int `json:"first_token_ms"` + BillingType int8 `json:"billing_type"` + RequestType string `json:"request_type"` + Stream bool `json:"stream"` + OpenAIWSMode bool `json:"openai_ws_mode"` + DurationMs *int `json:"duration_ms"` + FirstTokenMs *int `json:"first_token_ms"` // 图片生成字段 ImageCount int `json:"image_count"` ImageSize *string `json:"image_size"` + MediaType *string `json:"media_type"` // User-Agent UserAgent *string `json:"user_agent"` + // Cache TTL Override 标记 + CacheTTLOverridden bool `json:"cache_ttl_overridden"` + CreatedAt time.Time `json:"created_at"` User *User `json:"user,omitempty"` @@ -308,6 +364,7 @@ type UsageCleanupFilters struct { AccountID *int64 `json:"account_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"` Model *string `json:"model,omitempty"` + RequestType *string `json:"request_type,omitempty"` Stream *bool `json:"stream,omitempty"` BillingType *int8 `json:"billing_type,omitempty"` } @@ -379,9 +436,12 @@ type AdminUserSubscription struct { type BulkAssignResult struct { SuccessCount int `json:"success_count"` + CreatedCount int `json:"created_count"` + ReusedCount int `json:"reused_count"` FailedCount int `json:"failed_count"` Subscriptions []AdminUserSubscription `json:"subscriptions"` Errors []string `json:"errors"` + Statuses map[string]string `json:"statuses,omitempty"` } // PromoCode 注册优惠码 diff --git a/backend/internal/handler/failover_loop.go b/backend/internal/handler/failover_loop.go new file mode 100644 index 00000000..b2583301 --- /dev/null +++ b/backend/internal/handler/failover_loop.go @@ -0,0 +1,174 @@ +package handler + +import ( + "context" + "net/http" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/service" + "go.uber.org/zap" +) + +// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。 +// GatewayService 隐式实现此接口。 +type TempUnscheduler interface { + TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) +} + +// FailoverAction 表示 failover 错误处理后的下一步动作 +type FailoverAction int + +const ( + // FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue) + FailoverContinue FailoverAction = iota + // FailoverExhausted 切换次数耗尽(调用方应返回错误响应) + FailoverExhausted + // FailoverCanceled context 已取消(调用方应直接 return) + FailoverCanceled +) + +const ( + // maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误) + maxSameAccountRetries = 2 + // sameAccountRetryDelay 同账号重试间隔 + sameAccountRetryDelay = 500 * time.Millisecond + // singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。 + // Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s), + // Handler 层只需短暂间隔后重新进入 Service 层即可。 + singleAccountBackoffDelay = 2 * time.Second +) + +// FailoverState 跨循环迭代共享的 failover 状态 +type FailoverState struct { + SwitchCount int + MaxSwitches int + FailedAccountIDs map[int64]struct{} + SameAccountRetryCount map[int64]int + LastFailoverErr *service.UpstreamFailoverError + ForceCacheBilling bool + hasBoundSession bool +} + +// NewFailoverState 创建 failover 状态 +func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState { + return &FailoverState{ + MaxSwitches: maxSwitches, + FailedAccountIDs: make(map[int64]struct{}), + SameAccountRetryCount: make(map[int64]int), + hasBoundSession: hasBoundSession, + } +} + +// HandleFailoverError 处理 UpstreamFailoverError,返回下一步动作。 +// 包含:缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。 +func (s *FailoverState) HandleFailoverError( + ctx context.Context, + gatewayService TempUnscheduler, + accountID int64, + platform string, + failoverErr *service.UpstreamFailoverError, +) FailoverAction { + s.LastFailoverErr = failoverErr + + // 缓存计费判断 + if needForceCacheBilling(s.hasBoundSession, failoverErr) { + s.ForceCacheBilling = true + } + + // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 + if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries { + s.SameAccountRetryCount[accountID]++ + logger.FromContext(ctx).Warn("gateway.failover_same_account_retry", + zap.Int64("account_id", accountID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("same_account_retry_count", s.SameAccountRetryCount[accountID]), + zap.Int("same_account_retry_max", maxSameAccountRetries), + ) + if !sleepWithContext(ctx, sameAccountRetryDelay) { + return FailoverCanceled + } + return FailoverContinue + } + + // 同账号重试用尽,执行临时封禁 + if failoverErr.RetryableOnSameAccount { + gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr) + } + + // 加入失败列表 + s.FailedAccountIDs[accountID] = struct{}{} + + // 检查是否耗尽 + if s.SwitchCount >= s.MaxSwitches { + return FailoverExhausted + } + + // 递增切换计数 + s.SwitchCount++ + logger.FromContext(ctx).Warn("gateway.failover_switch_account", + zap.Int64("account_id", accountID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) + + // Antigravity 平台换号线性递增延时 + if platform == service.PlatformAntigravity { + delay := time.Duration(s.SwitchCount-1) * time.Second + if !sleepWithContext(ctx, delay) { + return FailoverCanceled + } + } + + return FailoverContinue +} + +// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。 +// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景: +// 清除排除列表、等待退避后重新选号。 +// +// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。 +// 返回 FailoverExhausted 时,调用方应返回错误响应。 +// 返回 FailoverCanceled 时,调用方应直接 return。 +func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction { + if s.LastFailoverErr != nil && + s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable && + s.SwitchCount <= s.MaxSwitches { + + logger.FromContext(ctx).Warn("gateway.failover_single_account_backoff", + zap.Duration("backoff_delay", singleAccountBackoffDelay), + zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) + if !sleepWithContext(ctx, singleAccountBackoffDelay) { + return FailoverCanceled + } + logger.FromContext(ctx).Warn("gateway.failover_single_account_retry", + zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) + s.FailedAccountIDs = make(map[int64]struct{}) + return FailoverContinue + } + return FailoverExhausted +} + +// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。 +// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。 +func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool { + return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling) +} + +// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。 +func sleepWithContext(ctx context.Context, d time.Duration) bool { + if d <= 0 { + return true + } + select { + case <-ctx.Done(): + return false + case <-time.After(d): + return true + } +} diff --git a/backend/internal/handler/failover_loop_test.go b/backend/internal/handler/failover_loop_test.go new file mode 100644 index 00000000..5a41b2dd --- /dev/null +++ b/backend/internal/handler/failover_loop_test.go @@ -0,0 +1,732 @@ +package handler + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mock +// --------------------------------------------------------------------------- + +// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。 +type mockTempUnscheduler struct { + calls []tempUnscheduleCall +} + +type tempUnscheduleCall struct { + accountID int64 + failoverErr *service.UpstreamFailoverError +} + +func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) { + m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr}) +} + +// --------------------------------------------------------------------------- +// Helper +// --------------------------------------------------------------------------- + +func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError { + return &service.UpstreamFailoverError{ + StatusCode: statusCode, + RetryableOnSameAccount: retryable, + ForceCacheBilling: forceBilling, + } +} + +// --------------------------------------------------------------------------- +// NewFailoverState 测试 +// --------------------------------------------------------------------------- + +func TestNewFailoverState(t *testing.T) { + t.Run("初始化字段正确", func(t *testing.T) { + fs := NewFailoverState(5, true) + require.Equal(t, 5, fs.MaxSwitches) + require.Equal(t, 0, fs.SwitchCount) + require.NotNil(t, fs.FailedAccountIDs) + require.Empty(t, fs.FailedAccountIDs) + require.NotNil(t, fs.SameAccountRetryCount) + require.Empty(t, fs.SameAccountRetryCount) + require.Nil(t, fs.LastFailoverErr) + require.False(t, fs.ForceCacheBilling) + require.True(t, fs.hasBoundSession) + }) + + t.Run("无绑定会话", func(t *testing.T) { + fs := NewFailoverState(3, false) + require.Equal(t, 3, fs.MaxSwitches) + require.False(t, fs.hasBoundSession) + }) + + t.Run("零最大切换次数", func(t *testing.T) { + fs := NewFailoverState(0, false) + require.Equal(t, 0, fs.MaxSwitches) + }) +} + +// --------------------------------------------------------------------------- +// sleepWithContext 测试 +// --------------------------------------------------------------------------- + +func TestSleepWithContext(t *testing.T) { + t.Run("零时长立即返回true", func(t *testing.T) { + start := time.Now() + ok := sleepWithContext(context.Background(), 0) + require.True(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("负时长立即返回true", func(t *testing.T) { + start := time.Now() + ok := sleepWithContext(context.Background(), -1*time.Second) + require.True(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("正常等待后返回true", func(t *testing.T) { + start := time.Now() + ok := sleepWithContext(context.Background(), 50*time.Millisecond) + elapsed := time.Since(start) + require.True(t, ok) + require.GreaterOrEqual(t, elapsed, 40*time.Millisecond) + require.Less(t, elapsed, 500*time.Millisecond) + }) + + t.Run("已取消context立即返回false", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + ok := sleepWithContext(ctx, 5*time.Second) + require.False(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("等待期间context取消返回false", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(30 * time.Millisecond) + cancel() + }() + + start := time.Now() + ok := sleepWithContext(ctx, 5*time.Second) + elapsed := time.Since(start) + require.False(t, ok) + require.Less(t, elapsed, 500*time.Millisecond) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 基本切换流程 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_BasicSwitch(t *testing.T) { + t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + require.Equal(t, err, fs.LastFailoverErr) + require.False(t, fs.ForceCacheBilling) + require.Empty(t, mock.calls, "不应调用 TempUnschedule") + }) + + t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) { + // switchCount 从 0→1 时,sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0 + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0") + }) + + t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) { + // switchCount 从 1→2 时,sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 // 模拟已切换一次 + + err := newTestFailoverErr(500, false, false) + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s") + require.Less(t, elapsed, 3*time.Second) + }) + + t.Run("连续切换直到耗尽", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(2, false) + + // 第一次切换:0→1 + err1 := newTestFailoverErr(500, false, false) + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + + // 第二次切换:1→2 + err2 := newTestFailoverErr(502, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + + // 第三次已耗尽:SwitchCount(2) >= MaxSwitches(2) + err3 := newTestFailoverErr(503, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3) + require.Equal(t, FailoverExhausted, action) + require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增") + + // 验证失败账号列表 + require.Len(t, fs.FailedAccountIDs, 3) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + require.Contains(t, fs.FailedAccountIDs, int64(200)) + require.Contains(t, fs.FailedAccountIDs, int64(300)) + + // LastFailoverErr 应为最后一次的错误 + require.Equal(t, err3, fs.LastFailoverErr) + }) + + t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(0, false) + err := newTestFailoverErr(500, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverExhausted, action) + require.Equal(t, 0, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 缓存计费 (ForceCacheBilling) +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_CacheBilling(t *testing.T) { + t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, true) // hasBoundSession=true + err := newTestFailoverErr(500, false, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.True(t, fs.ForceCacheBilling) + }) + + t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.True(t, fs.ForceCacheBilling) + }) + + t.Run("两者均为false时不设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.False(t, fs.ForceCacheBilling) + }) + + t.Run("一旦设置不会被后续错误重置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + // 第一次:ForceCacheBilling=true → 设置 + err1 := newTestFailoverErr(500, false, true) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.True(t, fs.ForceCacheBilling) + + // 第二次:ForceCacheBilling=false → 仍然保持 true + err2 := newTestFailoverErr(502, false, false) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 同账号重试 (RetryableOnSameAccount) +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_SameAccountRetry(t *testing.T) { + t.Run("第一次重试返回FailoverContinue", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数") + require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表") + require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule") + // 验证等待了 sameAccountRetryDelay (500ms) + require.GreaterOrEqual(t, elapsed, 400*time.Millisecond) + require.Less(t, elapsed, 2*time.Second) + }) + + t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + // 第一次 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + + // 第二次 + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SameAccountRetryCount[100]) + + require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule") + }) + + t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + // 第一次、第二次重试 + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, 2, fs.SameAccountRetryCount[100]) + + // 第三次:重试已达到 maxSameAccountRetries(2),应切换账号 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + + // 验证 TempUnschedule 被调用 + require.Len(t, mock.calls, 1) + require.Equal(t, int64(100), mock.calls[0].accountID) + require.Equal(t, err, mock.calls[0].failoverErr) + }) + + t.Run("不同账号独立跟踪重试次数", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + err := newTestFailoverErr(400, true, false) + + // 账号 100 第一次重试 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + + // 账号 200 第一次重试(独立计数) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[200]) + require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响") + }) + + t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + err := newTestFailoverErr(400, true, false) + + // 耗尽账号 100 的重试 + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + // 第三次: 重试耗尽 → 切换 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + + // 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换 + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — TempUnschedule 调用验证 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_TempUnschedule(t *testing.T) { + t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Empty(t, mock.calls) + }) + + t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(502, true, false) + + // 耗尽重试 + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + + require.Len(t, mock.calls, 1) + require.Equal(t, int64(42), mock.calls[0].accountID) + require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode) + require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — Context 取消 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_ContextCanceled(t *testing.T) { + t.Run("同账号重试sleep期间context取消", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + start := time.Now() + action := fs.HandleFailoverError(ctx, mock, 100, "openai", err) + elapsed := time.Since(start) + + require.Equal(t, FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回") + // 重试计数仍应递增 + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + }) + + t.Run("Antigravity延迟期间context取消", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s + err := newTestFailoverErr(500, false, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + start := time.Now() + action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — FailedAccountIDs 跟踪 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_FailedAccountIDs(t *testing.T) { + t.Run("切换时添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + + fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false)) + require.Contains(t, fs.FailedAccountIDs, int64(200)) + require.Len(t, fs.FailedAccountIDs, 2) + }) + + t.Run("耗尽时也添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(0, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Equal(t, FailoverExhausted, action) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + }) + + t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(400, true, false)) + require.Equal(t, FailoverContinue, action) + require.NotContains(t, fs.FailedAccountIDs, int64(100)) + }) + + t.Run("同一账号多次切换不重复添加", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — LastFailoverErr 更新 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_LastFailoverErr(t *testing.T) { + t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + err1 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.Equal(t, err1, fs.LastFailoverErr) + + err2 := newTestFailoverErr(502, false, false) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.Equal(t, err2, fs.LastFailoverErr) + }) + + t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + err := newTestFailoverErr(400, true, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, err, fs.LastFailoverErr) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 综合集成场景 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_IntegrationScenario(t *testing.T) { + t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, true) // hasBoundSession=true + + // 1. 账号 100 遇到可重试错误,同账号重试 2 次 + retryErr := newTestFailoverErr(400, true, false) + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling") + + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + + // 2. 账号 100 重试耗尽 → TempUnschedule + 切换 + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Len(t, mock.calls, 1) + + // 3. 账号 200 遇到不可重试错误 → 直接切换 + switchErr := newTestFailoverErr(500, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + + // 4. 账号 300 遇到不可重试错误 → 再切换 + action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 3, fs.SwitchCount) + + // 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3) + action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr) + require.Equal(t, FailoverExhausted, action) + + // 最终状态验证 + require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增") + require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中") + require.True(t, fs.ForceCacheBilling) + require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule") + }) + + t.Run("模拟Antigravity平台完整流程", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(2, false) + + err := newTestFailoverErr(500, false, false) + + // 第一次切换:delay = 0s + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + require.Equal(t, FailoverContinue, action) + require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0") + + // 第二次切换:delay = 1s + start = time.Now() + action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err) + elapsed = time.Since(start) + require.Equal(t, FailoverContinue, action) + require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s") + + // 第三次:耗尽(无延迟,因为在检查延迟之前就返回了) + start = time.Now() + action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err) + elapsed = time.Since(start) + require.Equal(t, FailoverExhausted, action) + require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟") + }) + + t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) // hasBoundSession=false + + // 第一次:ForceCacheBilling=false + err1 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.False(t, fs.ForceCacheBilling) + + // 第二次:ForceCacheBilling=true(Antigravity 粘性会话切换) + err2 := newTestFailoverErr(500, false, true) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling") + + // 第三次:ForceCacheBilling=false,但状态仍保持 true + err3 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3) + require.True(t, fs.ForceCacheBilling, "不应重置") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 边界条件 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_EdgeCases(t *testing.T) { + t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(0, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + }) + + t.Run("AccountID为0也能正常跟踪", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, true, false) + + action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[0]) + }) + + t.Run("负AccountID也能正常跟踪", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, true, false) + + action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[-1]) + }) + + t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 + err := newTestFailoverErr(500, false, false) + + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, "", err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟") + }) +} + +// --------------------------------------------------------------------------- +// HandleSelectionExhausted 测试 +// --------------------------------------------------------------------------- + +func TestHandleSelectionExhausted(t *testing.T) { + t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(3, false) + // LastFailoverErr 为 nil + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverExhausted, action) + }) + + t.Run("非503错误返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(500, false, false) + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverExhausted, action) + }) + + t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.FailedAccountIDs[100] = struct{}{} + fs.SwitchCount = 1 + + start := time.Now() + action := fs.HandleSelectionExhausted(context.Background()) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表") + require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s") + require.Less(t, elapsed, 5*time.Second) + }) + + t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(2, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.SwitchCount = 3 // > MaxSwitches(2) + + start := time.Now() + action := fs.HandleSelectionExhausted(context.Background()) + elapsed := time.Since(start) + + require.Equal(t, FailoverExhausted, action) + require.Less(t, elapsed, 100*time.Millisecond, "不应等待") + }) + + t.Run("503但context已取消_返回Canceled", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + action := fs.HandleSelectionExhausted(ctx) + elapsed := time.Since(start) + + require.Equal(t, FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回") + }) + + t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) { + fs := NewFailoverState(2, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.SwitchCount = 2 // == MaxSwitches,条件是 <=,仍可重试 + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverContinue, action) + }) +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index ca4442e4..8d39c767 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -6,25 +6,33 @@ import ( "encoding/json" "errors" "fmt" - "io" - "log" "net/http" + "strconv" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) +const gatewayCompatibilityMetricsLogInterval = 1024 + +var gatewayCompatibilityMetricsLogCounter atomic.Uint64 + // GatewayHandler handles API gateway requests type GatewayHandler struct { gatewayService *service.GatewayService @@ -34,10 +42,14 @@ type GatewayHandler struct { billingCacheService *service.BillingCacheService usageService *service.UsageService apiKeyService *service.APIKeyService + usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper + userMsgQueueHelper *UserMsgQueueHelper maxAccountSwitches int maxAccountSwitchesGemini int + cfg *config.Config + settingService *service.SettingService } // NewGatewayHandler creates a new GatewayHandler @@ -50,8 +62,11 @@ func NewGatewayHandler( billingCacheService *service.BillingCacheService, usageService *service.UsageService, apiKeyService *service.APIKeyService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, + userMsgQueueService *service.UserMessageQueueService, cfg *config.Config, + settingService *service.SettingService, ) *GatewayHandler { pingInterval := time.Duration(0) maxAccountSwitches := 10 @@ -65,6 +80,13 @@ func NewGatewayHandler( maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini } } + + // 初始化用户消息串行队列 helper + var umqHelper *UserMsgQueueHelper + if userMsgQueueService != nil && cfg != nil { + umqHelper = NewUserMsgQueueHelper(userMsgQueueService, SSEPingFormatClaude, pingInterval) + } + return &GatewayHandler{ gatewayService: gatewayService, geminiCompatService: geminiCompatService, @@ -73,10 +95,14 @@ func NewGatewayHandler( billingCacheService: billingCacheService, usageService: usageService, apiKeyService: apiKeyService, + usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), + userMsgQueueHelper: umqHelper, maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, + cfg: cfg, + settingService: settingService, } } @@ -95,9 +121,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return } + reqLog := requestLogger( + c, + "handler.gateway.messages", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + defer h.maybeLogCompatibilityFallbackMetrics(reqLog) // 读取请求体 - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -114,27 +148,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, "", false, body) - parsedReq, err := service.ParseGatewayRequest(body) + parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } reqModel := parsedReq.Model reqStream := parsedReq.Stream + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { - ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true) + ctx := service.WithIsMaxTokensOneHaikuRequest(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } - // 检查是否为 Claude Code 客户端,设置到 context 中 - SetClaudeCodeClientContext(c, body) + // 检查是否为 Claude Code 客户端,设置到 context 中(复用已解析请求,避免二次反序列化)。 + SetClaudeCodeClientContext(c, body, parsedReq) isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context()) + // 版本检查:仅对 Claude Code 客户端,拒绝低于最低版本的请求 + if !h.checkClaudeCodeVersion(c) { + return + } + // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) + c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) setOpsRequestContext(c, reqModel, reqStream, body) @@ -160,9 +200,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) waitCounted := false if err != nil { - log.Printf("Increment wait count failed: %v", err) + reqLog.Warn("gateway.user_wait_counter_increment_failed", zap.Error(err)) // On error, allow request to proceed } else if !canWait { + reqLog.Info("gateway.user_wait_queue_full", zap.Int("max_wait", maxWait)) h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") return } @@ -179,7 +220,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 1. 首先获取用户并发槽位 userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) if err != nil { - log.Printf("User concurrency acquire failed: %v", err) + reqLog.Warn("gateway.user_slot_acquire_failed", zap.Error(err)) h.handleConcurrencyError(c, err, "user", streamStarted) return } @@ -196,13 +237,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 2. 【新增】Wait后二次检查余额/订阅 if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - log.Printf("Billing eligibility check failed after wait: %v", err) + reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err)) status, code, message := billingErrorDetails(err) h.handleStreamingAwareError(c, status, code, message, streamStarted) return } // 计算粘性会话hash + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台 @@ -221,33 +267,54 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var sessionBoundAccountID int64 if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + if sessionBoundAccountID > 0 { + prefetchedGroupID := int64(0) + if apiKey.GroupID != nil { + prefetchedGroupID = *apiKey.GroupID + } + ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + } } // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 if platform == service.PlatformGemini { - maxAccountSwitches := h.maxAccountSwitchesGemini - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - var lastFailoverErr *service.UpstreamFailoverError - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 + fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession) + + // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 + // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 + if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) { + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { - if len(failedAccountIDs) == 0 { + if len(fs.FailedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - if lastFailoverErr != nil { - h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted) - } else { - h.handleFailoverExhaustedSimple(c, 502, streamStarted) + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + if fs.LastFailoverErr != nil { + h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } + return } - return } account := selection.Account - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { @@ -275,21 +342,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountWaitCounted := false canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { - log.Printf("Increment account wait count failed: %v", err) + reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) + reqLog.Info("gateway.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) return } if err == nil && canWait { accountWaitCounted = true } - // Ensure the wait counter is decremented if we exit before acquiring the slot. - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -300,17 +370,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { &streamStarted, ) if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) + reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + releaseWait() h.handleConcurrencyError(c, err, "account", streamStarted) return } // Slot acquired: no longer waiting in queue. - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) + reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } // 账号槽位/等待计数需要在超时或断开时安全回收 @@ -319,8 +387,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } if account.Platform == service.PlatformAntigravity { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession) @@ -333,46 +401,62 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - lastFailoverErr = failoverErr - if failoverErr.ForceCacheBilling { - forceCacheBilling = true - } - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted) + return + case FailoverCanceled: return } - switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) - continue } - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Forward request failed: %v", err) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } + // RPM 计数递增(Forward 成功后) + // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 + // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 + if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { + reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + } + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, - Account: usedAccount, + Account: account, Subscription: subscription, - UserAgent: ua, + UserAgent: userAgent, IPAddress: clientIP, - ForceCacheBilling: fcb, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.gateway.messages"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("gateway.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP, forceCacheBilling) + }) return } } @@ -385,31 +469,44 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } fallbackUsed := false + // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 + // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 + if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) { + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + } + for { - maxAccountSwitches := h.maxAccountSwitches - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - var lastFailoverErr *service.UpstreamFailoverError + fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession) retryWithFallback := false - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID) if err != nil { - if len(failedAccountIDs) == 0 { + if len(fs.FailedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - if lastFailoverErr != nil { - h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted) - } else { - h.handleFailoverExhaustedSimple(c, 502, streamStarted) + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + if fs.LastFailoverErr != nil { + h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } + return } - return } account := selection.Account - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { @@ -437,20 +534,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountWaitCounted := false canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { - log.Printf("Increment account wait count failed: %v", err) + reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) + reqLog.Info("gateway.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) return } if err == nil && canWait { accountWaitCounted = true } - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -461,50 +562,117 @@ func (h *GatewayHandler) Messages(c *gin.Context) { &streamStarted, ) if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) + reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + releaseWait() h.handleConcurrencyError(c, err, "account", streamStarted) return } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + // Slot acquired: no longer waiting in queue. + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) + reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + // ===== 用户消息串行队列 START ===== + var queueRelease func() + umqMode := h.getUserMsgQueueMode(account, parsedReq) + + switch umqMode { + case config.UMQModeSerialize: + // 串行模式:获取锁 + RPM 延迟 + 释放(当前行为不变) + baseRPM := account.GetBaseRPM() + release, qErr := h.userMsgQueueHelper.AcquireWithWait( + c, account.ID, baseRPM, reqStream, &streamStarted, + h.cfg.Gateway.UserMessageQueue.WaitTimeout(), + reqLog, + ) + if qErr != nil { + // fail-open: 记录 warn,不阻止请求 + reqLog.Warn("gateway.umq_acquire_failed", + zap.Int64("account_id", account.ID), + zap.Error(qErr), + ) + } else { + queueRelease = release + } + + case config.UMQModeThrottle: + // 软性限速:仅施加 RPM 自适应延迟,不阻塞并发 + baseRPM := account.GetBaseRPM() + if tErr := h.userMsgQueueHelper.ThrottleWithPing( + c, account.ID, baseRPM, reqStream, &streamStarted, + h.cfg.Gateway.UserMessageQueue.WaitTimeout(), + reqLog, + ); tErr != nil { + reqLog.Warn("gateway.umq_throttle_failed", + zap.Int64("account_id", account.ID), + zap.Error(tErr), + ) + } + + default: + if umqMode != "" { + reqLog.Warn("gateway.umq_unknown_mode", + zap.String("mode", umqMode), + zap.Int64("account_id", account.ID), + ) + } + } + + // 用 wrapReleaseOnDone 确保 context 取消时自动释放(仅 serialize 模式有 queueRelease) + queueRelease = wrapReleaseOnDone(c.Request.Context(), queueRelease) + // 注入回调到 ParsedRequest:使用外层 wrapper 以便提前清理 AfterFunc + parsedReq.OnUpstreamAccepted = queueRelease + // ===== 用户消息串行队列 END ===== + // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) } else { result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) } + + // 兜底释放串行锁(正常情况已通过回调提前释放) + if queueRelease != nil { + queueRelease() + } + // 清理回调引用,防止 failover 重试时旧回调被错误调用 + parsedReq.OnUpstreamAccepted = nil + if accountReleaseFunc != nil { accountReleaseFunc() } if err != nil { var promptTooLongErr *service.PromptTooLongError if errors.As(err, &promptTooLongErr) { - log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed) + reqLog.Warn("gateway.prompt_too_long_from_antigravity", + zap.Any("current_group_id", currentAPIKey.GroupID), + zap.Any("fallback_group_id", fallbackGroupID), + zap.Bool("fallback_used", fallbackUsed), + ) if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 { fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID) if err != nil { - log.Printf("Resolve fallback group failed: %v", err) + reqLog.Warn("gateway.resolve_fallback_group_failed", zap.Int64("fallback_group_id", *fallbackGroupID), zap.Error(err)) _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) return } if fallbackGroup.Platform != service.PlatformAnthropic || fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription || fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { - log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType) + reqLog.Warn("gateway.fallback_group_invalid", + zap.Int64("fallback_group_id", fallbackGroup.ID), + zap.String("fallback_platform", fallbackGroup.Platform), + zap.String("fallback_subscription_type", fallbackGroup.SubscriptionType), + ) _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) return } @@ -514,7 +682,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, status, code, message, streamStarted) return } - // 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度 + // 兜底重试按"直接请求兜底分组"处理:清除强制平台,允许按分组平台调度 ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "") c.Request = c.Request.WithContext(ctx) currentAPIKey = fallbackAPIKey @@ -528,46 +696,62 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - lastFailoverErr = failoverErr - if failoverErr.ForceCacheBilling { - forceCacheBilling = true - } - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted) + return + case FailoverCanceled: return } - switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) - continue } - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Account %d: Forward request failed: %v", account.ID, err) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } + // RPM 计数递增(Forward 成功后) + // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 + // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 + if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { + reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + } + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, APIKey: currentAPIKey, User: currentAPIKey.User, - Account: usedAccount, + Account: account, Subscription: currentSubscription, - UserAgent: ua, + UserAgent: userAgent, IPAddress: clientIP, - ForceCacheBilling: fcb, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.gateway.messages"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", currentAPIKey.ID), + zap.Any("group_id", currentAPIKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("gateway.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP, forceCacheBilling) + }) return } if !retryWithFallback { @@ -590,6 +774,17 @@ func (h *GatewayHandler) Models(c *gin.Context) { groupID = &apiKey.Group.ID platform = apiKey.Group.Platform } + if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" { + platform = forcedPlatform + } + + if platform == service.PlatformSora { + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": service.DefaultSoraModels(h.cfg), + }) + return + } // Get available models from account configurations (without platform filter) availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") @@ -820,6 +1015,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se msg = *rule.CustomMessage } + if rule.SkipMonitoring { + c.Set(service.OpsSkipPassthroughKey, true) + } + h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) return } @@ -859,20 +1058,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send error event in SSE format with proper JSON marshaling - errorData := map[string]any{ - "type": "error", - "error": map[string]string{ - "type": errType, - "message": message, - }, - } - jsonBytes, err := json.Marshal(errorData) - if err != nil { - _ = c.Error(err) - return - } - errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes)) + // SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。 + errorEvent := `data: {"type":"error","error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n" if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } @@ -885,6 +1072,50 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e h.errorResponse(c, status, errType, message) } +// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 +func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) + return true +} + +// checkClaudeCodeVersion 检查 Claude Code 客户端版本是否满足最低要求 +// 仅对已识别的 Claude Code 客户端执行,count_tokens 路径除外 +func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool { + ctx := c.Request.Context() + if !service.IsClaudeCodeClient(ctx) { + return true + } + + // 排除 count_tokens 子路径 + if strings.HasSuffix(c.Request.URL.Path, "/count_tokens") { + return true + } + + minVersion := h.settingService.GetMinClaudeCodeVersion(ctx) + if minVersion == "" { + return true // 未设置,不检查 + } + + clientVersion := service.GetClaudeCodeVersion(ctx) + if clientVersion == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", + "Unable to determine Claude Code version. Please update Claude Code: npm update -g @anthropic-ai/claude-code") + return false + } + + if service.CompareVersions(clientVersion, minVersion) < 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", + fmt.Sprintf("Your Claude Code version (%s) is below the minimum required version (%s). Please update: npm update -g @anthropic-ai/claude-code", + clientVersion, minVersion)) + return false + } + + return true +} + // errorResponse 返回Claude API格式的错误响应 func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ @@ -912,9 +1143,16 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return } + reqLog := requestLogger( + c, + "handler.gateway.count_tokens", + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + defer h.maybeLogCompatibilityFallbackMetrics(reqLog) // 读取请求体 - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -929,18 +1167,18 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } - // 检查是否为 Claude Code 客户端,设置到 context 中 - SetClaudeCodeClientContext(c, body) - setOpsRequestContext(c, "", false, body) - parsedReq, err := service.ParseGatewayRequest(body) + parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } + // count_tokens 走 messages 严格校验时,复用已解析请求,避免二次反序列化。 + SetClaudeCodeClientContext(c, body, parsedReq) + reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream)) // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) + c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) // 验证 model 必填 if parsedReq.Model == "" { @@ -962,19 +1200,25 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { } // 计算粘性会话 hash + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 选择支持该模型的账号 account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) if err != nil { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) + reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err)) + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable") return } - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 转发请求(不记录使用量) if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil { - log.Printf("Forward count_tokens request failed: %v", err) + reqLog.Error("gateway.count_tokens_forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) // 错误响应已在 ForwardCountTokens 中处理 return } @@ -1098,24 +1342,8 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce textDeltas = []string{"New", " Conversation"} } - // Build message_start event with proper JSON marshaling - messageStart := map[string]any{ - "type": "message_start", - "message": map[string]any{ - "id": msgID, - "type": "message", - "role": "assistant", - "model": model, - "content": []any{}, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]int{ - "input_tokens": 10, - "output_tokens": 0, - }, - }, - } - messageStartJSON, _ := json.Marshal(messageStart) + // Build message_start event with fixed schema. + messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}` // Build events events := []string{ @@ -1125,31 +1353,12 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce // Add text deltas for _, text := range textDeltas { - delta := map[string]any{ - "type": "content_block_delta", - "index": 0, - "delta": map[string]string{ - "type": "text_delta", - "text": text, - }, - } - deltaJSON, _ := json.Marshal(delta) + deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}` events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON)) } // Add final events - messageDelta := map[string]any{ - "type": "message_delta", - "delta": map[string]any{ - "stop_reason": "end_turn", - "stop_sequence": nil, - }, - "usage": map[string]int{ - "input_tokens": 10, - "output_tokens": outputTokens, - }, - } - messageDeltaJSON, _ := json.Marshal(messageDelta) + messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}` events = append(events, `event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`, @@ -1238,7 +1447,78 @@ func billingErrorDetails(err error) (status int, code, message string) { } msg := pkgerrors.Message(err) if msg == "" { - msg = err.Error() + logger.L().With( + zap.String("component", "handler.gateway.billing"), + zap.Error(err), + ).Warn("gateway.billing_error_missing_message") + msg = "Billing error" } return http.StatusForbidden, "billing_error", msg } + +func (h *GatewayHandler) metadataBridgeEnabled() bool { + if h == nil || h.cfg == nil { + return true + } + return h.cfg.Gateway.OpenAIWS.MetadataBridgeEnabled +} + +func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger) { + if reqLog == nil { + return + } + if gatewayCompatibilityMetricsLogCounter.Add(1)%gatewayCompatibilityMetricsLogInterval != 0 { + return + } + metrics := service.SnapshotOpenAICompatibilityFallbackMetrics() + reqLog.Info("gateway.compatibility_fallback_metrics", + zap.Int64("session_hash_legacy_read_fallback_total", metrics.SessionHashLegacyReadFallbackTotal), + zap.Int64("session_hash_legacy_read_fallback_hit", metrics.SessionHashLegacyReadFallbackHit), + zap.Int64("session_hash_legacy_dual_write_total", metrics.SessionHashLegacyDualWriteTotal), + zap.Float64("session_hash_legacy_read_hit_rate", metrics.SessionHashLegacyReadHitRate), + zap.Int64("metadata_legacy_fallback_total", metrics.MetadataLegacyFallbackTotal), + ) +} + +func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { + if task == nil { + return + } + if h.usageRecordWorkerPool != nil { + h.usageRecordWorkerPool.Submit(task) + return + } + // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.gateway.messages"), + zap.Any("panic", recovered), + ).Error("gateway.usage_record_task_panic_recovered") + } + }() + task(ctx) +} + +// getUserMsgQueueMode 获取当前请求的 UMQ 模式 +// 返回 "serialize" | "throttle" | "" +func (h *GatewayHandler) getUserMsgQueueMode(account *service.Account, parsed *service.ParsedRequest) string { + if h.userMsgQueueHelper == nil { + return "" + } + // 仅适用于 Anthropic OAuth/SetupToken 账号 + if !account.IsAnthropicOAuthOrSetupToken() { + return "" + } + if !service.IsRealUserMessage(parsed) { + return "" + } + // 账号级模式优先,fallback 到全局配置 + mode := account.GetUserMsgQueueMode() + if mode == "" { + mode = h.cfg.Gateway.UserMessageQueue.GetEffectiveMode() + } + return mode +} diff --git a/backend/internal/handler/gateway_handler_error_fallback_test.go b/backend/internal/handler/gateway_handler_error_fallback_test.go new file mode 100644 index 00000000..4fce5ec1 --- /dev/null +++ b/backend/internal/handler/gateway_handler_error_fallback_test.go @@ -0,0 +1,49 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &GatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + assert.Equal(t, "error", parsed["type"]) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusTeapot, "already written") + + h := &GatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.False(t, wrote) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go new file mode 100644 index 00000000..2afa6440 --- /dev/null +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -0,0 +1,348 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "encoding/json" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// 目标:严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”, +// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时, +// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。 + +type fakeSchedulerCache struct { + accounts []*service.Account +} + +func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) { + return f.accounts, true, nil +} +func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error { + return nil +} +func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) { + return nil, nil +} +func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil } +func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil } +func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error { + return nil +} +func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) { + return true, nil +} +func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) { + return nil, nil +} +func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil } +func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil } + +type fakeGroupRepo struct { + group *service.Group +} + +func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil } +func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) { + return f.group, nil +} +func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) { + return f.group, nil +} +func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil } +func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil } +func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil } +func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil } +func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) { + return nil, nil +} +func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil } +func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil } +func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) { + return nil, nil +} +func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil } +func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error { + return nil +} + +type fakeConcurrencyCache struct{} + +func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil } +func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) { + return 0, nil +} +func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil } +func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) { + return 0, nil +} +func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil } +func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil } +func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil } +func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + return map[int64]*service.AccountLoadInfo{}, nil +} +func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + return map[int64]*service.UserLoadInfo{}, nil +} +func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, id := range accountIDs { + result[id] = 0 + } + return result, nil +} +func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil } + +func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) { + t.Helper() + + schedulerCache := &fakeSchedulerCache{accounts: accounts} + schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil) + + gwSvc := service.NewGatewayService( + nil, // accountRepo (not used: scheduler snapshot hit) + &fakeGroupRepo{group: group}, + nil, // usageLogRepo + nil, // userRepo + nil, // userSubRepo + nil, // userGroupRateRepo + nil, // cache (disable sticky) + nil, // cfg + schedulerSnapshot, + nil, // concurrencyService (disable load-aware; tryAcquire always acquired) + nil, // billingService + nil, // rateLimitService + nil, // billingCacheService + nil, // identityService + nil, // httpUpstream + nil, // deferredService + nil, // claudeTokenProvider + nil, // sessionLimitCache + nil, // rpmCache + nil, // digestStore + ) + + // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 + cfg := &config.Config{RunMode: config.RunModeSimple} + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg) + + concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) + concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) + + h := &GatewayHandler{ + gatewayService: gwSvc, + billingCacheService: billingCacheSvc, + concurrencyHelper: concurrencyHelper, + // 这些字段对本测试不敏感,保持较小即可 + maxAccountSwitches: 1, + maxAccountSwitchesGemini: 1, + } + + cleanup := func() { + billingCacheSvc.Stop() + } + return h, cleanup +} + +func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(2001) + accountID := int64(1001) + + group := &service.Group{ + ID: groupID, + Hydrated: true, + Platform: service.PlatformAnthropic, // /v1/messages(Claude兼容)入口 + Status: service.StatusActive, + } + + account := &service.Account{ + ID: accountID, + Name: "ag-1", + Platform: service.PlatformAntigravity, + Type: service.AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "tok_xxx", + "intercept_warmup_requests": true, + }, + Extra: map[string]any{ + "mixed_scheduling": true, // 关键:允许被 anthropic 分组混合调度选中 + }, + Concurrency: 1, + Priority: 1, + Status: service.StatusActive, + Schedulable: true, + AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}}, + } + + h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account}) + defer cleanup() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + body := []byte(`{ + "model": "claude-sonnet-4-5", + "max_tokens": 256, + "messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}] + }`) + req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group)) + c.Request = req + + apiKey := &service.APIKey{ + ID: 3001, + UserID: 4001, + GroupID: &groupID, + Status: service.StatusActive, + User: &service.User{ + ID: 4001, + Concurrency: 10, + Balance: 100, + }, + Group: group, + } + + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10}) + + h.Messages(c) + + require.Equal(t, 200, rec.Code) + + // 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果) + selected, ok := c.Get(opsAccountIDKey) + require.True(t, ok) + require.Equal(t, accountID, selected) + + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "msg_mock_warmup", resp["id"]) + require.Equal(t, "claude-sonnet-4-5", resp["model"]) + + content, ok := resp["content"].([]any) + require.True(t, ok) + require.Len(t, content, 1) + first, ok := content[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "New Conversation", first["text"]) +} + +func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(2002) + accountID := int64(1002) + + group := &service.Group{ + ID: groupID, + Hydrated: true, + Platform: service.PlatformAntigravity, + Status: service.StatusActive, + } + + account := &service.Account{ + ID: accountID, + Name: "ag-2", + Platform: service.PlatformAntigravity, + Type: service.AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "tok_xxx", + "intercept_warmup_requests": true, + }, + Concurrency: 1, + Priority: 1, + Status: service.StatusActive, + Schedulable: true, + AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}}, + } + + h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account}) + defer cleanup() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + body := []byte(`{ + "model": "claude-sonnet-4-5", + "max_tokens": 256, + "messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}] + }`) + req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + // 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果: + // - 写入 request.Context(Service读取) + // - 写入 gin.Context(Handler快速读取) + ctx := context.WithValue(req.Context(), ctxkey.Group, group) + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity) + req = req.WithContext(ctx) + c.Request = req + c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity) + + apiKey := &service.APIKey{ + ID: 3002, + UserID: 4002, + GroupID: &groupID, + Status: service.StatusActive, + User: &service.User{ + ID: 4002, + Concurrency: 10, + Balance: 100, + }, + Group: group, + } + + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10}) + + h.Messages(c) + + require.Equal(t, 200, rec.Code) + + selected, ok := c.Get(opsAccountIDKey) + require.True(t, ok) + require.Equal(t, accountID, selected) + + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "msg_mock_warmup", resp["id"]) + require.Equal(t, "claude-sonnet-4-5", resp["model"]) +} diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 0393f954..09e6c09b 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -4,8 +4,9 @@ import ( "context" "encoding/json" "fmt" - "math/rand" + "math/rand/v2" "net/http" + "strings" "sync" "time" @@ -17,23 +18,91 @@ import ( // claudeCodeValidator is a singleton validator for Claude Code client detection var claudeCodeValidator = service.NewClaudeCodeValidator() +const claudeCodeParsedRequestContextKey = "claude_code_parsed_request" + // SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中 // 返回更新后的 context -func SetClaudeCodeClientContext(c *gin.Context, body []byte) { - // 解析请求体为 map - var bodyMap map[string]any - if len(body) > 0 { - _ = json.Unmarshal(body, &bodyMap) +func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) { + if c == nil || c.Request == nil { + return + } + if parsedReq != nil { + c.Set(claudeCodeParsedRequestContextKey, parsedReq) } - // 验证是否为 Claude Code 客户端 - isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap) + ua := c.GetHeader("User-Agent") + // Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。 + if !claudeCodeValidator.ValidateUserAgent(ua) { + ctx := service.SetClaudeCodeClient(c.Request.Context(), false) + c.Request = c.Request.WithContext(ctx) + return + } + + isClaudeCode := false + if !strings.Contains(c.Request.URL.Path, "messages") { + // 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。 + isClaudeCode = true + } else { + // 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。 + bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq) + if bodyMap == nil { + bodyMap = claudeCodeBodyMapFromContextCache(c) + } + if bodyMap == nil && len(body) > 0 { + _ = json.Unmarshal(body, &bodyMap) + } + isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap) + } // 更新 request context ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode) + + // 仅在确认为 Claude Code 客户端时提取版本号写入 context + if isClaudeCode { + if version := claudeCodeValidator.ExtractVersion(ua); version != "" { + ctx = service.SetClaudeCodeVersion(ctx, version) + } + } + c.Request = c.Request.WithContext(ctx) } +func claudeCodeBodyMapFromParsedRequest(parsedReq *service.ParsedRequest) map[string]any { + if parsedReq == nil { + return nil + } + bodyMap := map[string]any{ + "model": parsedReq.Model, + } + if parsedReq.System != nil || parsedReq.HasSystem { + bodyMap["system"] = parsedReq.System + } + if parsedReq.MetadataUserID != "" { + bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID} + } + return bodyMap +} + +func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any { + if c == nil { + return nil + } + if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok { + if bodyMap, ok := cached.(map[string]any); ok { + return bodyMap + } + } + if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok { + switch v := cached.(type) { + case *service.ParsedRequest: + return claudeCodeBodyMapFromParsedRequest(v) + case service.ParsedRequest: + return claudeCodeBodyMapFromParsedRequest(&v) + } + } + return nil +} + // 并发槽位等待相关常量 // // 性能优化说明: @@ -104,31 +173,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo // wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation. // 用于避免客户端断开或上游超时导致的并发槽位泄漏。 -// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露 +// 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。 func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() { if releaseFunc == nil { return nil } var once sync.Once - quit := make(chan struct{}) + var stop func() bool release := func() { once.Do(func() { + if stop != nil { + _ = stop() + } releaseFunc() - close(quit) // 通知监听 goroutine 退出 }) } - go func() { - select { - case <-ctx.Done(): - // Context 取消时释放资源 - release() - case <-quit: - // 正常释放已完成,goroutine 退出 - return - } - }() + stop = context.AfterFunc(ctx, release) return release } @@ -153,6 +215,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou h.concurrencyService.DecrementAccountWaitCount(ctx, accountID) } +// TryAcquireUserSlot 尝试立即获取用户并发槽位。 +// 返回值: (releaseFunc, acquired, error) +func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) { + result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) + if err != nil { + return nil, false, err + } + if !result.Acquired { + return nil, false, nil + } + return result.ReleaseFunc, true, nil +} + +// TryAcquireAccountSlot 尝试立即获取账号并发槽位。 +// 返回值: (releaseFunc, acquired, error) +func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) { + result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) + if err != nil { + return nil, false, err + } + if !result.Acquired { + return nil, false, nil + } + return result.ReleaseFunc, true, nil +} + // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. @@ -160,13 +248,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64 ctx := c.Request.Context() // Try to acquire immediately - result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) + releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency) if err != nil { return nil, err } - if result.Acquired { - return result.ReleaseFunc, nil + if acquired { + return releaseFunc, nil } // Need to wait - handle streaming ping if needed @@ -180,13 +268,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID ctx := c.Request.Context() // Try to acquire immediately - result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) + releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency) if err != nil { return nil, err } - if result.Acquired { - return result.ReleaseFunc, nil + if acquired { + return releaseFunc, nil } // Need to wait - handle streaming ping if needed @@ -196,27 +284,29 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // streamStarted pointer is updated when streaming begins (for proper error handling by caller). func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted) + return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false) } // waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout. -func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { +func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) { ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() - // Try immediate acquire first (avoid unnecessary wait) - var result *service.AcquireResult - var err error - if slotType == "user" { - result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) - } else { - result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) + acquireSlot := func() (*service.AcquireResult, error) { + if slotType == "user" { + return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) + } + return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) } - if err != nil { - return nil, err - } - if result.Acquired { - return result.ReleaseFunc, nil + + if tryImmediate { + result, err := acquireSlot() + if err != nil { + return nil, err + } + if result.Acquired { + return result.ReleaseFunc, nil + } } // Determine if ping is needed (streaming + ping format defined) @@ -242,7 +332,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType backoff := initialBackoff timer := time.NewTimer(backoff) defer timer.Stop() - rng := rand.New(rand.NewSource(time.Now().UnixNano())) for { select { @@ -268,15 +357,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType case <-timer.C: // Try to acquire slot - var result *service.AcquireResult - var err error - - if slotType == "user" { - result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) - } else { - result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) - } - + result, err := acquireSlot() if err != nil { return nil, err } @@ -284,7 +365,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType if result.Acquired { return result.ReleaseFunc, nil } - backoff = nextBackoff(backoff, rng) + backoff = nextBackoff(backoff) timer.Reset(backoff) } } @@ -292,26 +373,22 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType // AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping). func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted) + return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true) } // nextBackoff 计算下一次退避时间 // 性能优化:使用指数退避 + 随机抖动,避免惊群效应 // current: 当前退避时间 -// rng: 随机数生成器(可为 nil,此时不添加抖动) // 返回值:下一次退避时间(100ms ~ 2s 之间) -func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration { +func nextBackoff(current time.Duration) time.Duration { // 指数退避:当前时间 * 1.5 next := time.Duration(float64(current) * backoffMultiplier) if next > maxBackoff { next = maxBackoff } - if rng == nil { - return next - } // 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2) // 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis - jitter := 0.8 + rng.Float64()*0.4 + jitter := 0.8 + rand.Float64()*0.4 jittered := time.Duration(float64(next) * jitter) if jittered < initialBackoff { return initialBackoff diff --git a/backend/internal/handler/gateway_helper_backoff_test.go b/backend/internal/handler/gateway_helper_backoff_test.go new file mode 100644 index 00000000..a5056bbb --- /dev/null +++ b/backend/internal/handler/gateway_helper_backoff_test.go @@ -0,0 +1,106 @@ +package handler + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 --- + +func TestNextBackoff_ExponentialGrowth(t *testing.T) { + // 验证退避时间指数增长(乘数 1.5) + // 由于有随机抖动(±20%),需要验证范围 + current := initialBackoff // 100ms + + for i := 0; i < 10; i++ { + next := nextBackoff(current) + + // 退避结果应在 [initialBackoff, maxBackoff] 范围内 + assert.GreaterOrEqual(t, int64(next), int64(initialBackoff), + "第 %d 次退避不应低于初始值 %v", i, initialBackoff) + assert.LessOrEqual(t, int64(next), int64(maxBackoff), + "第 %d 次退避不应超过最大值 %v", i, maxBackoff) + + // 为下一轮提供当前退避值 + current = next + } +} + +func TestNextBackoff_BoundedByMaxBackoff(t *testing.T) { + // 即使输入非常大,输出也不超过 maxBackoff + for i := 0; i < 100; i++ { + result := nextBackoff(10 * time.Second) + assert.LessOrEqual(t, int64(result), int64(maxBackoff), + "退避值不应超过 maxBackoff") + } +} + +func TestNextBackoff_BoundedByInitialBackoff(t *testing.T) { + // 即使输入非常小,输出也不低于 initialBackoff + for i := 0; i < 100; i++ { + result := nextBackoff(1 * time.Millisecond) + assert.GreaterOrEqual(t, int64(result), int64(initialBackoff), + "退避值不应低于 initialBackoff") + } +} + +func TestNextBackoff_HasJitter(t *testing.T) { + // 验证多次调用会产生不同的值(随机抖动生效) + // 使用相同的输入调用 50 次,收集结果 + results := make(map[time.Duration]bool) + current := 500 * time.Millisecond + + for i := 0; i < 50; i++ { + result := nextBackoff(current) + results[result] = true + } + + // 50 次调用应该至少有 2 个不同的值(抖动存在) + require.Greater(t, len(results), 1, + "nextBackoff 应产生随机抖动,但所有 50 次调用结果相同") +} + +func TestNextBackoff_InitialValueGrows(t *testing.T) { + // 验证从初始值开始,退避趋势是增长的 + current := initialBackoff + var sum time.Duration + + runs := 100 + for i := 0; i < runs; i++ { + next := nextBackoff(current) + sum += next + current = next + } + + avg := sum / time.Duration(runs) + // 平均退避时间应大于初始值(因为指数增长 + 上限) + assert.Greater(t, int64(avg), int64(initialBackoff), + "平均退避时间应大于初始退避值") +} + +func TestNextBackoff_ConvergesToMaxBackoff(t *testing.T) { + // 从初始值开始,经过多次退避后应收敛到 maxBackoff 附近 + current := initialBackoff + for i := 0; i < 20; i++ { + current = nextBackoff(current) + } + + // 经过 20 次迭代后,应该已经到达 maxBackoff 区间 + // 由于抖动,允许 ±20% 的范围 + lowerBound := time.Duration(float64(maxBackoff) * 0.8) + assert.GreaterOrEqual(t, int64(current), int64(lowerBound), + "经过多次退避后应收敛到 maxBackoff 附近") +} + +func BenchmarkNextBackoff(b *testing.B) { + current := initialBackoff + for i := 0; i < b.N; i++ { + current = nextBackoff(current) + if current > maxBackoff { + current = initialBackoff + } + } +} diff --git a/backend/internal/handler/gateway_helper_fastpath_test.go b/backend/internal/handler/gateway_helper_fastpath_test.go new file mode 100644 index 00000000..31d489f0 --- /dev/null +++ b/backend/internal/handler/gateway_helper_fastpath_test.go @@ -0,0 +1,122 @@ +package handler + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +type concurrencyCacheMock struct { + acquireUserSlotFn func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) + acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) + releaseUserCalled int32 + releaseAccountCalled int32 +} + +func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + if m.acquireAccountSlotFn != nil { + return m.acquireAccountSlotFn(ctx, accountID, maxConcurrency, requestID) + } + return false, nil +} + +func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + atomic.AddInt32(&m.releaseAccountCalled, 1) + return nil +} + +func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 + } + return result, nil +} + +func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return true, nil +} + +func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + return nil +} + +func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + if m.acquireUserSlotFn != nil { + return m.acquireUserSlotFn(ctx, userID, maxConcurrency, requestID) + } + return false, nil +} + +func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { + atomic.AddInt32(&m.releaseUserCalled, 1) + return nil +} + +func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { + return true, nil +} + +func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error { + return nil +} + +func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + return map[int64]*service.AccountLoadInfo{}, nil +} + +func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + return map[int64]*service.UserLoadInfo{}, nil +} + +func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + return nil +} + +func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) { + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second) + + release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2) + require.NoError(t, err) + require.True(t, acquired) + require.NotNil(t, release) + + release() + require.Equal(t, int32(1), atomic.LoadInt32(&cache.releaseUserCalled)) +} + +func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) { + cache := &concurrencyCacheMock{ + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return false, nil + }, + } + helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second) + + release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1) + require.NoError(t, err) + require.False(t, acquired) + require.Nil(t, release) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.releaseAccountCalled)) +} diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go new file mode 100644 index 00000000..f8f7eaca --- /dev/null +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -0,0 +1,317 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type helperConcurrencyCacheStub struct { + mu sync.Mutex + + accountSeq []bool + userSeq []bool + + accountAcquireCalls int + userAcquireCalls int + accountReleaseCalls int + userReleaseCalls int +} + +func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.accountAcquireCalls++ + if len(s.accountSeq) == 0 { + return false, nil + } + v := s.accountSeq[0] + s.accountSeq = s.accountSeq[1:] + return v, nil +} + +func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.accountReleaseCalls++ + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + out := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + out[accountID] = 0 + } + return out, nil +} + +func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return true, nil +} + +func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.userAcquireCalls++ + if len(s.userSeq) == 0 { + return false, nil + } + v := s.userSeq[0] + s.userSeq = s.userSeq[1:] + return v, nil +} + +func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.userReleaseCalls++ + return nil +} + +func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { + return true, nil +} + +func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error { + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + out := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, acc := range accounts { + out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID} + } + return out, nil +} + +func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + out := make(map[int64]*service.UserLoadInfo, len(users)) + for _, user := range users { + out[user.ID] = &service.UserLoadInfo{UserID: user.ID} + } + return out, nil +} + +func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + return nil +} + +func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(method, path, nil) + return c, rec +} + +func validClaudeCodeBodyJSON() []byte { + return []byte(`{ + "model":"claude-3-5-sonnet-20241022", + "system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}], + "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"} + }`) +} + +func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { + t.Run("non_cli_user_agent_sets_false", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "curl/8.6.0") + + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil) + require.False(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_non_messages_path_sets_true", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodGet, "/v1/models") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + + SetClaudeCodeClientContext(c, nil, nil) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + c.Request.Header.Set("anthropic-version", "2023-06-01") + + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + // 缺少严格校验所需 header + body 字段 + SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`), nil) + require.False(t, service.IsClaudeCodeClient(c.Request.Context())) + }) +} + +func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing.T) { + t.Run("reuse parsed request without body unmarshal", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + c.Request.Header.Set("anthropic-version", "2023-06-01") + + parsedReq := &service.ParsedRequest{ + Model: "claude-3-5-sonnet-20241022", + System: []any{ + map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, + }, + MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123", + } + + // body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。 + SetClaudeCodeClientContext(c, []byte(`{invalid`), parsedReq) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("reuse context cache without body unmarshal", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + c.Request.Header.Set("anthropic-version", "2023-06-01") + c.Set(service.OpenAIParsedRequestBodyKey, map[string]any{ + "model": "claude-3-5-sonnet-20241022", + "system": []any{ + map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, + }, + "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}, + }) + + SetClaudeCodeClientContext(c, []byte(`{invalid`), nil) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) +} + +func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false, true}, + userSeq: []bool{false, true}, + } + concurrency := service.NewConcurrencyService(cache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + + t.Run("account_slot_acquired_after_retry", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true) + require.NoError(t, err) + require.NotNil(t, release) + require.False(t, streamStarted) + release() + require.GreaterOrEqual(t, cache.accountAcquireCalls, 2) + require.GreaterOrEqual(t, cache.accountReleaseCalls, 1) + }) + + t.Run("user_slot_acquired_after_retry", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true) + require.NoError(t, err) + require.NotNil(t, release) + release() + require.GreaterOrEqual(t, cache.userAcquireCalls, 2) + require.GreaterOrEqual(t, cache.userReleaseCalls, 1) + }) +} + +func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false, false, false}, + } + concurrency := service.NewConcurrencyService(cache) + + t.Run("timeout_returns_concurrency_error", func(t *testing.T) { + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + }) + + t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) { + helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond) + c, rec := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + require.True(t, streamStarted) + require.Contains(t, rec.Body.String(), ":\n\n") + }) +} + +func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) { + errCache := &helperConcurrencyCacheStubWithError{ + err: errors.New("redis unavailable"), + } + concurrency := service.NewConcurrencyService(errCache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true) + require.Nil(t, release) + require.Error(t, err) + require.Contains(t, err.Error(), "redis unavailable") +} + +func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false}, + } + concurrency := service.NewConcurrencyService(cache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + + release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + require.GreaterOrEqual(t, cache.accountAcquireCalls, 1) +} + +type helperConcurrencyCacheStubWithError struct { + helperConcurrencyCacheStub + err error +} + +func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return false, s.err +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index b1477ac6..50af9c8f 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -7,36 +7,29 @@ import ( "encoding/hex" "encoding/json" "errors" - "io" - "log" "net/http" "regexp" "strings" - "time" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/google/uuid" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) // geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值 // 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希] var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`) -func isGeminiCLIRequest(c *gin.Context, body []byte) bool { - if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" { - return true - } - return geminiCLITmpDirRegex.Match(body) -} - // GeminiV1BetaListModels proxies: // GET /v1beta/models func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { @@ -149,6 +142,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusInternalServerError, "User context not found") return } + reqLog := requestLogger( + c, + "handler.gemini_v1beta.models", + zap.Int64("user_id", authSubject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) // 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组 if !middleware.HasForcePlatform(c) { @@ -165,8 +165,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } stream := action == "streamGenerateContent" + reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream)) - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit)) @@ -193,8 +194,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait) waitCounted := false if err != nil { - log.Printf("Increment wait count failed: %v", err) + reqLog.Warn("gemini.user_wait_counter_increment_failed", zap.Error(err)) } else if !canWait { + reqLog.Info("gemini.user_wait_queue_full", zap.Int("max_wait", maxWait)) googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") return } @@ -214,6 +216,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted) if err != nil { + reqLog.Warn("gemini.user_slot_acquire_failed", zap.Error(err)) googleError(c, http.StatusTooManyRequests, err.Error()) return } @@ -229,6 +232,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 2) billing eligibility check (after wait) if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err)) status, _, message := billingErrorDetails(err) googleError(c, status, message) return @@ -239,7 +243,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { sessionHash := extractGeminiCLISessionHash(c, body) if sessionHash == "" { // Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端) - parsedReq, _ := service.ParseGatewayRequest(body) + parsedReq, _ := service.ParseGatewayRequest(body, domain.PlatformGemini) + if parsedReq != nil { + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } + } sessionHash = h.gatewayService.GenerateSessionHash(parsedReq) } sessionKey := sessionHash @@ -251,6 +262,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var sessionBoundAccountID int64 if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + if sessionBoundAccountID > 0 { + prefetchedGroupID := int64(0) + if apiKey.GroupID != nil { + prefetchedGroupID = *apiKey.GroupID + } + ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + } } // === Gemini 内容摘要会话 Fallback 逻辑 === @@ -258,6 +277,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var geminiDigestChain string var geminiPrefixHash string var geminiSessionUUID string + var matchedDigestChain string useDigestFallback := sessionBoundAccountID == 0 if useDigestFallback { @@ -284,17 +304,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ) // 查找会话 - foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession( + foundUUID, foundAccountID, foundMatchedChain, found := h.gatewayService.FindGeminiSession( c.Request.Context(), derefGroupID(apiKey.GroupID), geminiPrefixHash, geminiDigestChain, ) if found { + matchedDigestChain = foundMatchedChain sessionBoundAccountID = foundAccountID geminiSessionUUID = foundUUID - log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s", - safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain)) + reqLog.Info("gemini.digest_fallback_matched", + zap.String("session_uuid_prefix", safeShortPrefix(foundUUID, 8)), + zap.Int64("account_id", foundAccountID), + zap.String("digest_chain", truncateDigestChain(geminiDigestChain)), + ) // 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey // 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号 @@ -316,38 +340,56 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 - isCLI := isGeminiCLIRequest(c, body) cleanedForUnknownBinding := false - maxAccountSwitches := h.maxAccountSwitchesGemini - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - var lastFailoverErr *service.UpstreamFailoverError - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 + fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession) + + // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 + // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 + if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) { + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { - if len(failedAccountIDs) == 0 { + if len(fs.FailedAccountIDs) == 0 { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } - h.handleGeminiFailoverExhausted(c, lastFailoverErr) - return + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr) + return + } } account := selection.Account - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature // 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。 if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID { - log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID) + reqLog.Info("gemini.sticky_session_account_switched", + zap.Int64("from_account_id", sessionBoundAccountID), + zap.Int64("to_account_id", account.ID), + zap.Bool("clean_thought_signature", true), + ) body = service.CleanGeminiNativeThoughtSignatures(body) sessionBoundAccountID = account.ID - } else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) { - // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。 + } else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) { + // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。 // 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。 - log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively") + reqLog.Info("gemini.sticky_session_binding_missing", + zap.Bool("clean_thought_signature", true), + ) body = service.CleanGeminiNativeThoughtSignatures(body) cleanedForUnknownBinding = true sessionBoundAccountID = account.ID @@ -366,9 +408,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { accountWaitCounted := false canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { - log.Printf("Increment account wait count failed: %v", err) + reqLog.Warn("gemini.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) + reqLog.Info("gemini.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") return } @@ -390,6 +435,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { &streamStarted, ) if err != nil { + reqLog.Warn("gemini.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) googleError(c, http.StatusTooManyRequests, err.Error()) return } @@ -398,7 +444,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { accountWaitCounted = false } if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) + reqLog.Warn("gemini.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } // 账号槽位/等待计数需要在超时或断开时安全回收 @@ -407,10 +453,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 5) forward (根据平台分流) var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) } else { result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body) @@ -421,22 +467,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - if failoverErr.ForceCacheBilling { - forceCacheBilling = true - } - if switchCount >= maxAccountSwitches { - lastFailoverErr = failoverErr - h.handleGeminiFailoverExhausted(c, lastFailoverErr) + failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch failoverAction { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr) + return + case FailoverCanceled: return } - lastFailoverErr = failoverErr - switchCount++ - log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) - continue } // ForwardNative already wrote the response - log.Printf("Gemini native forward failed: %v", err) + reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) return } @@ -453,32 +496,41 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { geminiDigestChain, geminiSessionUUID, account.ID, + matchedDigestChain, ); err != nil { - log.Printf("[Gemini] Failed to save digest session: %v", err) + reqLog.Warn("gemini.digest_session_save_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } - // 6) record usage async (Gemini 使用长上下文双倍计费) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ Result: result, APIKey: apiKey, User: apiKey.User, - Account: usedAccount, + Account: account, Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + UserAgent: userAgent, + IPAddress: clientIP, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 - ForceCacheBilling: fcb, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.gemini_v1beta.models"), + zap.Int64("user_id", authSubject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", modelName), + zap.Int64("account_id", account.ID), + ).Error("gemini.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP, forceCacheBilling) + }) + reqLog.Debug("gemini.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", fs.SwitchCount), + ) return } } @@ -526,6 +578,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE msg = *rule.CustomMessage } + if rule.SkipMonitoring { + c.Set(service.OpsSkipPassthroughKey, true) + } + googleError(c, respCode, msg) return } diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b2b12c0d..1e1247fc 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -11,6 +11,7 @@ type AdminHandlers struct { Group *admin.GroupHandler Account *admin.AccountHandler Announcement *admin.AnnouncementHandler + DataManagement *admin.DataManagementHandler OAuth *admin.OAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler GeminiOAuth *admin.GeminiOAuthHandler @@ -25,6 +26,7 @@ type AdminHandlers struct { Usage *admin.UsageHandler UserAttribute *admin.UserAttributeHandler ErrorPassthrough *admin.ErrorPassthroughHandler + APIKey *admin.AdminAPIKeyHandler } // Handlers contains all HTTP handlers @@ -39,6 +41,8 @@ type Handlers struct { Admin *AdminHandlers Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler + SoraGateway *SoraGatewayHandler + SoraClient *SoraClientHandler Setting *SettingHandler Totp *TotpHandler } diff --git a/backend/internal/handler/idempotency_helper.go b/backend/internal/handler/idempotency_helper.go new file mode 100644 index 00000000..bca63b6b --- /dev/null +++ b/backend/internal/handler/idempotency_helper.go @@ -0,0 +1,65 @@ +package handler + +import ( + "context" + "strconv" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +func executeUserIdempotentJSON( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + coordinator := service.DefaultIdempotencyCoordinator() + if coordinator == nil { + data, err := execute(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) + return + } + + actorScope := "user:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "user:" + strconv.FormatInt(subject.UserID, 10) + } + + result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{ + Scope: scope, + ActorScope: actorScope, + Method: c.Request.Method, + Route: c.FullPath(), + IdempotencyKey: c.GetHeader("Idempotency-Key"), + Payload: payload, + RequireKey: true, + TTL: ttl, + }, execute) + if err != nil { + if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) { + service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_fail_close") + logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope) + } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } + response.ErrorFrom(c, err) + return + } + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) +} diff --git a/backend/internal/handler/idempotency_helper_test.go b/backend/internal/handler/idempotency_helper_test.go new file mode 100644 index 00000000..e8213a2b --- /dev/null +++ b/backend/internal/handler/idempotency_helper_test.go @@ -0,0 +1,285 @@ +package handler + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userStoreUnavailableRepoStub struct{} + +func (userStoreUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +type userMemoryIdempotencyRepoStub struct { + mu sync.Mutex + nextID int64 + data map[string]*service.IdempotencyRecord +} + +func newUserMemoryIdempotencyRepoStub() *userMemoryIdempotencyRepoStub { + return &userMemoryIdempotencyRepoStub{ + nextID: 1, + data: make(map[string]*service.IdempotencyRecord), + } +} + +func (r *userMemoryIdempotencyRepoStub) key(scope, keyHash string) string { + return scope + "|" + keyHash +} + +func (r *userMemoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + return &out +} + +func (r *userMemoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + cp := r.clone(record) + cp.ID = r.nextID + r.nextID++ + r.data[k] = cp + record.ID = cp.ID + return true, nil +} + +func (r *userMemoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.clone(r.data[r.key(scope, keyHash)]), nil +} + +func (r *userMemoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = service.IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + return true, nil + } + return false, nil +} + +func (r *userMemoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + return true, nil + } + return false, nil +} + +func (r *userMemoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + rec.ErrorReason = nil + return nil + } + return nil +} + +func (r *userMemoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.ErrorReason = &errorReason + return nil + } + return nil +} + +func (r *userMemoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) { + return 0, nil +} + +func withUserSubject(userID int64) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID}) + c.Next() + } +} + +func TestExecuteUserIdempotentJSONFallbackWithoutCoordinator(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(nil) + + var executed int + router := gin.New() + router.Use(withUserSubject(1)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 1, executed) +} + +func TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(userStoreUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.Use(withUserSubject(2)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "k1") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + require.Equal(t, 0, executed) +} + +func TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := newUserMemoryIdempotencyRepoStub() + cfg := service.DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg)) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed atomic.Int32 + router := gin.New() + router.Use(withUserSubject(3)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed.Add(1) + time.Sleep(80 * time.Millisecond) + return gin.H{"ok": true}, nil + }) + }) + + call := func() (int, http.Header) { + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "same-user-key") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + return rec.Code, rec.Header() + } + + var status1, status2 int + var wg sync.WaitGroup + wg.Add(2) + go func() { defer wg.Done(); status1, _ = call() }() + go func() { defer wg.Done(); status2, _ = call() }() + wg.Wait() + + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1) + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2) + require.Equal(t, int32(1), executed.Load()) + + status3, headers3 := call() + require.Equal(t, http.StatusOK, status3) + require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed")) + require.Equal(t, int32(1), executed.Load()) +} diff --git a/backend/internal/handler/logging.go b/backend/internal/handler/logging.go new file mode 100644 index 00000000..2d5e6e22 --- /dev/null +++ b/backend/internal/handler/logging.go @@ -0,0 +1,19 @@ +package handler + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +func requestLogger(c *gin.Context, component string, fields ...zap.Field) *zap.Logger { + base := logger.L() + if c != nil && c.Request != nil { + base = logger.FromContext(c.Request.Context()) + } + + if component != "" { + fields = append([]zap.Field{zap.String("component", component)}, fields...) + } + return base.With(fields...) +} diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 835297b8..4bbd17ba 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -5,19 +5,23 @@ import ( "encoding/json" "errors" "fmt" - "io" - "log" "net/http" + "runtime/debug" + "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" - "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" + coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" ) // OpenAIGatewayHandler handles OpenAI API gateway requests @@ -25,6 +29,7 @@ type OpenAIGatewayHandler struct { gatewayService *service.OpenAIGatewayService billingCacheService *service.BillingCacheService apiKeyService *service.APIKeyService + usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int @@ -36,6 +41,7 @@ func NewOpenAIGatewayHandler( concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, apiKeyService *service.APIKeyService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, cfg *config.Config, ) *OpenAIGatewayHandler { @@ -51,6 +57,7 @@ func NewOpenAIGatewayHandler( gatewayService: gatewayService, billingCacheService: billingCacheService, apiKeyService: apiKeyService, + usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), maxAccountSwitches: maxAccountSwitches, @@ -60,6 +67,13 @@ func NewOpenAIGatewayHandler( // Responses handles OpenAI Responses API endpoint // POST /openai/v1/responses func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { + // 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。 + streamStarted := false + defer h.recoverResponsesPanic(c, &streamStarted) + setOpenAIClientTransportHTTP(c) + + requestStart := time.Now() + // Get apiKey and user from context (set by ApiKeyAuth middleware) apiKey, ok := middleware2.GetAPIKeyFromContext(c) if !ok { @@ -72,9 +86,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return } + reqLog := requestLogger( + c, + "handler.openai_gateway.responses", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } // Read request body - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -91,64 +115,51 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, "", false, body) - // Parse request body to map for potential modification - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err != nil { + // 校验请求体 JSON 合法性 + if !gjson.ValidBytes(body) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } - // Extract model and stream - reqModel, _ := reqBody["model"].(string) - reqStream, _ := reqBody["stream"].(bool) - - // 验证 model 必填 - if reqModel == "" { + // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") return } + reqModel := modelResult.String() - userAgent := c.GetHeader("User-Agent") - if !openai.IsCodexCLIRequest(userAgent) { - existingInstructions, _ := reqBody["instructions"].(string) - if strings.TrimSpace(existingInstructions) == "" { - if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" { - reqBody["instructions"] = instructions - // Re-serialize body - body, err = json.Marshal(reqBody) - if err != nil { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") - return - } - } + streamResult := gjson.GetBytes(body, "stream") + if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type") + return + } + reqStream := streamResult.Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()) + if previousResponseID != "" { + previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + reqLog = reqLog.With( + zap.Bool("has_previous_response_id", true), + zap.String("previous_response_id_kind", previousResponseIDKind), + zap.Int("previous_response_id_len", len(previousResponseID)), + ) + if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID { + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "previous_response_id_looks_like_message_id"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id") + return } } setOpsRequestContext(c, reqModel, reqStream, body) // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 - // 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call, - // 或带 id 且与 call_id 匹配的 item_reference。 - if service.HasFunctionCallOutput(reqBody) { - previousResponseID, _ := reqBody["previous_response_id"].(string) - if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) { - if service.HasFunctionCallOutputMissingCallID(reqBody) { - log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") - return - } - callIDs := service.FunctionCallOutputCallIDs(reqBody) - if !service.HasItemReferenceForCallIDs(reqBody, callIDs) { - log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") - return - } - } + if !h.validateFunctionCallOutputRequest(c, body, reqLog) { + return } - // Track if we've started streaming (for error handling) - streamStarted := false - // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) @@ -157,54 +168,28 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Get subscription info (may be nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) - // 0. Check if wait queue is full - maxWait := service.CalculateMaxWait(subject.Concurrency) - canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) - waitCounted := false - if err != nil { - log.Printf("Increment wait count failed: %v", err) - // On error, allow request to proceed - } else if !canWait { - h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") - return - } - if err == nil && canWait { - waitCounted = true - } - defer func() { - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - } - }() + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() - // 1. First acquire user concurrency slot - userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) - if err != nil { - log.Printf("User concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "user", streamStarted) + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog) + if !acquired { return } - // User slot acquired: no longer waiting. - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - waitCounted = false - } // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏 - userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { defer userReleaseFunc() } // 2. Re-check billing eligibility after wait if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - log.Printf("Billing eligibility check failed after wait: %v", err) + reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err)) status, code, message := billingErrorDetails(err) h.handleStreamingAwareError(c, status, code, message, streamStarted) return } // Generate session hash (header first; fallback to prompt_cache_key) - sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody) + sessionHash := h.gatewayService.GenerateSessionHash(c, body) maxAccountSwitches := h.maxAccountSwitches switchCount := 0 @@ -213,12 +198,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model - log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + previousResponseID, + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) if err != nil { - log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) + reqLog.Warn("openai.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } if lastFailoverErr != nil { @@ -228,67 +224,53 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } return } - account := selection.Account - log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) - setOpsSelectedAccount(c, account.ID) - - // 3. Acquire account concurrency slot - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - defer func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - }() - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + if previousResponseID != "" && selection != nil && selection.Account != nil { + reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID)) + } + reqLog.Debug("openai.account_schedule_decision", + zap.String("layer", scheduleDecision.Layer), + zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit), + zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit), + zap.Int("candidate_count", scheduleDecision.CandidateCount), + zap.Int("top_k", scheduleDecision.TopK), + zap.Int64("latency_ms", scheduleDecision.LatencyMs), + zap.Float64("load_skew", scheduleDecision.LoadSkew), + ) + account := selection.Account + reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) + setOpsSelectedAccount(c, account.ID, account.Platform) + + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog) + if !acquired { + return } - // 账号槽位/等待计数需要在超时或断开时安全回收 - accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) // Forward request + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { accountReleaseFunc() } + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + h.gatewayService.RecordOpenAIAccountSwitch() failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { @@ -296,39 +278,631 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + reqLog.Warn("openai.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) continue } - // Error response already handled in Forward, just log - log.Printf("Account %d: Forward request failed: %v", account.ID, err) + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + } + if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { + reqLog.Warn("openai.forward_failed", fields...) + return + } + reqLog.Error("openai.forward_failed", fields...) return } + if result != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + } // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // Async record usage - go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, - Account: usedAccount, + Account: account, Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + UserAgent: userAgent, + IPAddress: clientIP, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.openai_gateway.responses"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("openai.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP) + }) + reqLog.Debug("openai.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) return } } +func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool { + if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() { + return true + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + // 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。 + return true + } + + c.Set(service.OpenAIParsedRequestBodyKey, reqBody) + validation := service.ValidateFunctionCallOutputContext(reqBody) + if !validation.HasFunctionCallOutput { + return true + } + + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext { + return true + } + + if validation.HasFunctionCallOutputMissingCallID { + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "function_call_output_missing_call_id"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") + return false + } + if validation.HasItemReferenceForAllCallIDs { + return true + } + + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "function_call_output_missing_item_reference"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") + return false +} + +func (h *OpenAIGatewayHandler) acquireResponsesUserSlot( + c *gin.Context, + userID int64, + userConcurrency int, + reqStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), bool) { + ctx := c.Request.Context() + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency) + if err != nil { + reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", *streamStarted) + return nil, false + } + if userAcquired { + return wrapReleaseOnDone(ctx, userReleaseFunc), true + } + + maxWait := service.CalculateMaxWait(userConcurrency) + canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait) + if waitErr != nil { + reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr)) + // 按现有降级语义:等待计数异常时放行后续抢槽流程 + } else if !canWait { + reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait)) + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return nil, false + } + + waitCounted := waitErr == nil && canWait + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(ctx, userID) + } + }() + + userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted) + if err != nil { + reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", *streamStarted) + return nil, false + } + + // 槽位获取成功后,立刻退出等待计数。 + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(ctx, userID) + waitCounted = false + } + return wrapReleaseOnDone(ctx, userReleaseFunc), true +} + +func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot( + c *gin.Context, + groupID *int64, + sessionHash string, + selection *service.AccountSelectionResult, + reqStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), bool) { + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted) + return nil, false + } + + ctx := c.Request.Context() + account := selection.Account + if selection.Acquired { + return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true + } + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted) + return nil, false + } + + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + ctx, + account.ID, + selection.WaitPlan.MaxConcurrency, + ) + if err != nil { + reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", *streamStarted) + return nil, false + } + if fastAcquired { + if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + return wrapReleaseOnDone(ctx, fastReleaseFunc), true + } + + canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting) + if waitErr != nil { + reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr)) + } else if !canWait { + reqLog.Info("openai.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted) + return nil, false + } + + accountWaitCounted := waitErr == nil && canWait + releaseWait := func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID) + accountWaitCounted = false + } + } + defer releaseWait() + + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + streamStarted, + ) + if err != nil { + reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", *streamStarted) + return nil, false + } + + // Slot acquired: no longer waiting in queue. + releaseWait() + if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + return wrapReleaseOnDone(ctx, accountReleaseFunc), true +} + +// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint +// GET /openai/v1/responses (Upgrade: websocket) +func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { + if !isOpenAIWSUpgradeRequest(c.Request) { + h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)") + return + } + setOpenAIClientTransportWS(c) + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + + reqLog := requestLogger( + c, + "handler.openai_gateway.responses_ws", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.Bool("openai_ws_mode", true), + ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + reqLog.Info("openai.websocket_ingress_started") + clientIP := ip.GetClientIP(c) + userAgent := strings.TrimSpace(c.GetHeader("User-Agent")) + + wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + reqLog.Warn("openai.websocket_accept_failed", + zap.Error(err), + zap.String("client_ip", clientIP), + zap.String("request_user_agent", userAgent), + zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))), + zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))), + zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))), + zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""), + ) + return + } + defer func() { + _ = wsConn.CloseNow() + }() + wsConn.SetReadLimit(16 * 1024 * 1024) + + ctx := c.Request.Context() + readCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + msgType, firstMessage, err := wsConn.Read(readCtx) + cancel() + if err != nil { + closeStatus, closeReason := summarizeWSCloseErrorForLog(err) + reqLog.Warn("openai.websocket_read_first_message_failed", + zap.Error(err), + zap.String("client_ip", clientIP), + zap.String("close_status", closeStatus), + zap.String("close_reason", closeReason), + zap.Duration("read_timeout", 30*time.Second), + ) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message") + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type") + return + } + if !gjson.ValidBytes(firstMessage) { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload") + return + } + + reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String()) + if reqModel == "" { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload") + return + } + previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String()) + previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id") + return + } + reqLog = reqLog.With( + zap.Bool("ws_ingress", true), + zap.String("model", reqModel), + zap.Bool("has_previous_response_id", previousResponseID != ""), + zap.String("previous_response_id_kind", previousResponseIDKind), + ) + setOpsRequestContext(c, reqModel, true, firstMessage) + + var currentUserRelease func() + var currentAccountRelease func() + releaseTurnSlots := func() { + if currentAccountRelease != nil { + currentAccountRelease() + currentAccountRelease = nil + } + if currentUserRelease != nil { + currentUserRelease() + currentUserRelease = nil + } + } + // 必须尽早注册,确保任何 early return 都能释放已获取的并发槽位。 + defer releaseTurnSlots() + + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot") + return + } + if !userAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later") + return + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed") + return + } + + sessionHash := h.gatewayService.GenerateSessionHashWithFallback( + c, + firstMessage, + openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID), + ) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + ctx, + apiKey.GroupID, + previousResponseID, + sessionHash, + reqModel, + nil, + service.OpenAIUpstreamTransportResponsesWebsocketV2, + ) + if err != nil { + reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + return + } + if selection == nil || selection.Account == nil { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + return + } + + account := selection.Account + accountMaxConcurrency := account.Concurrency + if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 { + accountMaxConcurrency = selection.WaitPlan.MaxConcurrency + } + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + return + } + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + ctx, + account.ID, + selection.WaitPlan.MaxConcurrency, + ) + if err != nil { + reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot") + return + } + if !fastAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + return + } + accountReleaseFunc = fastReleaseFunc + } + currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) + if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + + token, _, err := h.gatewayService.GetAccessToken(ctx, account) + if err != nil { + reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") + return + } + + reqLog.Debug("openai.websocket_account_selected", + zap.Int64("account_id", account.ID), + zap.String("account_name", account.Name), + zap.String("schedule_layer", scheduleDecision.Layer), + zap.Int("candidate_count", scheduleDecision.CandidateCount), + ) + + hooks := &service.OpenAIWSIngressHooks{ + BeforeTurn: func(turn int) error { + if turn == 1 { + return nil + } + // 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。 + releaseTurnSlots() + // 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。 + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err) + } + if !userAcquired { + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil) + } + accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency) + if err != nil { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err) + } + if !accountAcquired { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil) + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) + return nil + }, + AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { + releaseTurnSlots() + if turnErr != nil || result == nil { + return + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + h.submitUsageRecordTask(func(taskCtx context.Context) { + if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("openai.websocket_record_usage_failed", + zap.Int64("account_id", account.ID), + zap.String("request_id", result.RequestID), + zap.Error(err), + ) + } + }) + }, + } + + if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + closeStatus, closeReason := summarizeWSCloseErrorForLog(err) + reqLog.Warn("openai.websocket_proxy_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + zap.String("close_status", closeStatus), + zap.String("close_reason", closeReason), + ) + var closeErr *service.OpenAIWSClientCloseError + if errors.As(err, &closeErr) { + closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason()) + return + } + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed") + return + } + reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID)) +} + +func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) { + recovered := recover() + if recovered == nil { + return + } + + started := false + if streamStarted != nil { + started = *streamStarted + } + wroteFallback := h.ensureForwardErrorResponse(c, started) + requestLogger(c, "handler.openai_gateway.responses").Error( + "openai.responses_panic_recovered", + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Any("panic", recovered), + zap.ByteString("stack", debug.Stack()), + ) +} + +func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool { + missing := h.missingResponsesDependencies() + if len(missing) == 0 { + return true + } + + if reqLog == nil { + reqLog = requestLogger(c, "handler.openai_gateway.responses") + } + reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing)) + + if c != nil && c.Writer != nil && !c.Writer.Written() { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Service temporarily unavailable", + }, + }) + } + return false +} + +func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string { + missing := make([]string, 0, 5) + if h == nil { + return append(missing, "handler") + } + if h.gatewayService == nil { + missing = append(missing, "gatewayService") + } + if h.billingCacheService == nil { + missing = append(missing, "billingCacheService") + } + if h.apiKeyService == nil { + missing = append(missing, "apiKeyService") + } + if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil { + missing = append(missing, "concurrencyHelper") + } + return missing +} + +func getContextInt64(c *gin.Context, key string) (int64, bool) { + if c == nil || key == "" { + return 0, false + } + v, ok := c.Get(key) + if !ok { + return 0, false + } + switch t := v.(type) { + case int64: + return t, true + case int: + return int64(t), true + case int32: + return int64(t), true + case float64: + return int64(t), true + default: + return 0, false + } +} + +func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { + if task == nil { + return + } + if h.usageRecordWorkerPool != nil { + h.usageRecordWorkerPool.Submit(task) + return + } + // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.responses"), + zap.Any("panic", recovered), + ).Error("openai.usage_record_task_panic_recovered") + } + }() + task(ctx) +} + // handleConcurrencyError handles concurrency-related errors with proper 429 response func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", @@ -354,6 +928,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE msg = *rule.CustomMessage } + if rule.SkipMonitoring { + c.Set(service.OpsSkipPassthroughKey, true) + } + h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) return } @@ -393,8 +971,8 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send error event in OpenAI SSE format - errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + // SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。 + errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n" if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } @@ -407,6 +985,25 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status h.errorResponse(c, status, errType, message) } +// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 +func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) + return true +} + +func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool { + if wroteFallback { + return false + } + if c == nil || c.Writer == nil { + return false + } + return c.Writer.Written() +} + // errorResponse returns OpenAI API format error response func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ @@ -416,3 +1013,61 @@ func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType }, }) } + +func setOpenAIClientTransportHTTP(c *gin.Context) { + service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP) +} + +func setOpenAIClientTransportWS(c *gin.Context) { + service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS) +} + +func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string { + gid := int64(0) + if groupID != nil { + gid = *groupID + } + return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID) +} + +func isOpenAIWSUpgradeRequest(r *http.Request) bool { + if r == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") { + return false + } + return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade") +} + +func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) { + if conn == nil { + return + } + reason = strings.TrimSpace(reason) + if len(reason) > 120 { + reason = reason[:120] + } + _ = conn.Close(status, reason) + _ = conn.CloseNow() +} + +func summarizeWSCloseErrorForLog(err error) (string, string) { + if err == nil { + return "-", "-" + } + statusCode := coderws.CloseStatus(err) + if statusCode == -1 { + return "-", "-" + } + closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) + closeReason := "-" + var closeErr coderws.CloseError + if errors.As(err, &closeErr) { + reason := strings.TrimSpace(closeErr.Reason) + if reason != "" { + closeReason = reason + } + } + return closeStatus, closeReason +} diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go new file mode 100644 index 00000000..a26b3a0c --- /dev/null +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -0,0 +1,677 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) { + tests := []struct { + name string + errType string + message string + }{ + { + name: "包含双引号的消息", + errType: "server_error", + message: `upstream returned "invalid" response`, + }, + { + name: "包含反斜杠的消息", + errType: "server_error", + message: `path C:\Users\test\file.txt not found`, + }, + { + name: "包含双引号和反斜杠的消息", + errType: "upstream_error", + message: `error parsing "key\value": unexpected token`, + }, + { + name: "包含换行符的消息", + errType: "server_error", + message: "line1\nline2\ttab", + }, + { + name: "普通消息", + errType: "upstream_error", + message: "Upstream service temporarily unavailable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true) + + body := w.Body.String() + + // 验证 SSE 格式:event: error\ndata: {JSON}\n\n + assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头") + assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾") + + // 提取 data 部分 + lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n") + require.Len(t, lines, 2, "应有 event 行和 data 行") + dataLine := lines[1] + require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头") + jsonStr := strings.TrimPrefix(dataLine, "data: ") + + // 验证 JSON 合法性 + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr) + + // 验证结构 + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok, "应包含 error 对象") + assert.Equal(t, tt.errType, errorObj["type"]) + assert.Equal(t, tt.message, errorObj["message"]) + }) + } +} + +func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "test error", false) + + // 非流式应返回 JSON 响应 + assert.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "test error", errorObj["message"]) +} + +func TestReadRequestBodyWithPrealloc(t *testing.T) { + payload := `{"model":"gpt-5","input":"hello"}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload)) + req.ContentLength = int64(len(payload)) + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(req) + require.NoError(t, err) + require.Equal(t, payload, string(body)) +} + +func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8))) + req.Body = http.MaxBytesReader(rec, req.Body, 4) + + _, err := pkghttputil.ReadRequestBodyWithPrealloc(req) + require.Error(t, err) + var maxErr *http.MaxBytesError + require.ErrorAs(t, err, &maxErr) +} + +func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.False(t, wrote) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} + +func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("fallback_written_should_not_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true)) + }) + + t.Run("context_nil_should_not_downgrade", func(t *testing.T) { + require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false)) + }) + + t.Run("response_not_written_should_not_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false)) + }) + + t.Run("response_already_written_should_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusForbidden, "already written") + require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false)) + }) +} + +func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + panic("test panic") + }() + }) + + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + }() + }) + + require.False(t, c.Writer.Written()) + assert.Equal(t, "", w.Body.String()) +} + +func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + panic("test panic") + }() + }) + + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} + +func TestOpenAIMissingResponsesDependencies(t *testing.T) { + t.Run("nil_handler", func(t *testing.T) { + var h *OpenAIGatewayHandler + require.Equal(t, []string{"handler"}, h.missingResponsesDependencies()) + }) + + t.Run("all_dependencies_missing", func(t *testing.T) { + h := &OpenAIGatewayHandler{} + require.Equal(t, + []string{"gatewayService", "billingCacheService", "apiKeyService", "concurrencyHelper"}, + h.missingResponsesDependencies(), + ) + }) + + t.Run("all_dependencies_present", func(t *testing.T) { + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{ + concurrencyService: &service.ConcurrencyService{}, + }, + } + require.Empty(t, h.missingResponsesDependencies()) + }) +} + +func TestOpenAIEnsureResponsesDependencies(t *testing.T) { + t.Run("missing_dependencies_returns_503", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + ok := h.ensureResponsesDependencies(c, nil) + + require.False(t, ok) + require.Equal(t, http.StatusServiceUnavailable, w.Code) + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, exists := parsed["error"].(map[string]any) + require.True(t, exists) + assert.Equal(t, "api_error", errorObj["type"]) + assert.Equal(t, "Service temporarily unavailable", errorObj["message"]) + }) + + t.Run("already_written_response_not_overridden", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + ok := h.ensureResponsesDependencies(c, nil) + + require.False(t, ok) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) + }) + + t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{ + concurrencyService: &service.ConcurrencyService{}, + }, + } + ok := h.ensureResponsesDependencies(c, nil) + + require.True(t, ok) + require.False(t, c.Writer.Written()) + assert.Equal(t, "", w.Body.String()) + }) +} + +func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`)) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 10, + GroupID: &groupID, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + // 故意使用未初始化依赖,验证快速失败而不是崩溃。 + h := &OpenAIGatewayHandler{} + require.NotPanics(t, func() { + h.Responses(c) + }) + + require.Equal(t, http.StatusServiceUnavailable, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "api_error", errorObj["type"]) + assert.Equal(t, "Service temporarily unavailable", errorObj["message"]) +} + +func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(`{"model":"gpt-5"}`)) + c.Request.Header.Set("Content-Type", "application/json") + + h := &OpenAIGatewayHandler{} + h.Responses(c) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader( + `{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`, + )) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: 1}, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + h.Responses(c) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "previous_response_id must be a response.id") +} + +func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil) + c.Request.Header.Set("Upgrade", "websocket") + c.Request.Header.Set("Connection", "Upgrade") + + h := &OpenAIGatewayHandler{} + h.ResponsesWebSocket(c) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + h.ResponsesWebSocket(c) + + require.Equal(t, http.StatusUpgradeRequired, w.Code) + require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx) + cancelRead() + require.Error(t, err) + var closeErr coderws.CloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) + require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id") +} + +func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) { + gin.SetMode(gin.TestMode) + + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return false, errors.New("user slot unavailable") + }, + } + h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache) + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx) + cancelRead() + require.Error(t, err) + var closeErr coderws.CloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusInternalError, closeErr.Code) + require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot") +} + +func TestSetOpenAIClientTransportHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + setOpenAIClientTransportHTTP(c) + require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c)) +} + +func TestSetOpenAIClientTransportWS(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + setOpenAIClientTransportWS(c) + require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c)) +} + +// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性 +func TestOpenAIHandler_GjsonExtraction(t *testing.T) { + tests := []struct { + name string + body string + wantModel string + wantStream bool + }{ + {"正常提取", `{"model":"gpt-4","stream":true,"input":"hello"}`, "gpt-4", true}, + {"stream false", `{"model":"gpt-4","stream":false}`, "gpt-4", false}, + {"无 stream 字段", `{"model":"gpt-4"}`, "gpt-4", false}, + {"model 缺失", `{"stream":true}`, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := []byte(tt.body) + modelResult := gjson.GetBytes(body, "model") + model := "" + if modelResult.Type == gjson.String { + model = modelResult.String() + } + stream := gjson.GetBytes(body, "stream").Bool() + require.Equal(t, tt.wantModel, model) + require.Equal(t, tt.wantStream, stream) + }) + } +} + +// TestOpenAIHandler_GjsonValidation 验证修复后的 JSON 合法性和类型校验 +func TestOpenAIHandler_GjsonValidation(t *testing.T) { + // 非法 JSON 被 gjson.ValidBytes 拦截 + require.False(t, gjson.ValidBytes([]byte(`{invalid json`))) + + // model 为数字 → 类型不是 gjson.String,应被拒绝 + body := []byte(`{"model":123}`) + modelResult := gjson.GetBytes(body, "model") + require.True(t, modelResult.Exists()) + require.NotEqual(t, gjson.String, modelResult.Type) + + // model 为 null → 类型不是 gjson.String,应被拒绝 + body2 := []byte(`{"model":null}`) + modelResult2 := gjson.GetBytes(body2, "model") + require.True(t, modelResult2.Exists()) + require.NotEqual(t, gjson.String, modelResult2.Type) + + // stream 为 string → 类型既不是 True 也不是 False,应被拒绝 + body3 := []byte(`{"model":"gpt-4","stream":"true"}`) + streamResult := gjson.GetBytes(body3, "stream") + require.True(t, streamResult.Exists()) + require.NotEqual(t, gjson.True, streamResult.Type) + require.NotEqual(t, gjson.False, streamResult.Type) + + // stream 为 int → 同上 + body4 := []byte(`{"model":"gpt-4","stream":1}`) + streamResult2 := gjson.GetBytes(body4, "stream") + require.True(t, streamResult2.Exists()) + require.NotEqual(t, gjson.True, streamResult2.Type) + require.NotEqual(t, gjson.False, streamResult2.Type) +} + +// TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑 +func TestOpenAIHandler_InstructionsInjection(t *testing.T) { + // 测试 1:无 instructions → 注入 + body := []byte(`{"model":"gpt-4"}`) + existing := gjson.GetBytes(body, "instructions").String() + require.Empty(t, existing) + newBody, err := sjson.SetBytes(body, "instructions", "test instruction") + require.NoError(t, err) + require.Equal(t, "test instruction", gjson.GetBytes(newBody, "instructions").String()) + + // 测试 2:已有 instructions → 不覆盖 + body2 := []byte(`{"model":"gpt-4","instructions":"existing"}`) + existing2 := gjson.GetBytes(body2, "instructions").String() + require.Equal(t, "existing", existing2) + + // 测试 3:空白 instructions → 注入 + body3 := []byte(`{"model":"gpt-4","instructions":" "}`) + existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String()) + require.Empty(t, existing3) + + // 测试 4:sjson.SetBytes 返回错误时不应 panic + // 正常 JSON 不会产生 sjson 错误,验证返回值被正确处理 + validBody := []byte(`{"model":"gpt-4"}`) + result, setErr := sjson.SetBytes(validBody, "instructions", "hello") + require.NoError(t, setErr) + require.True(t, gjson.ValidBytes(result)) +} + +func newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler { + t.Helper() + if cache == nil { + cache = &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + } + return &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second), + } +} + +func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server { + t.Helper() + groupID := int64(2) + apiKey := &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: subject.UserID}, + } + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), subject) + c.Next() + }) + router.GET("/openai/v1/responses", h.ResponsesWebSocket) + return httptest.NewServer(router) +} diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 36ffde63..2f53d655 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -41,9 +41,8 @@ const ( ) type opsErrorLogJob struct { - ops *service.OpsService - entry *service.OpsInsertErrorLogInput - requestBody []byte + ops *service.OpsService + entry *service.OpsInsertErrorLogInput } var ( @@ -58,6 +57,7 @@ var ( opsErrorLogEnqueued atomic.Int64 opsErrorLogDropped atomic.Int64 opsErrorLogProcessed atomic.Int64 + opsErrorLogSanitized atomic.Int64 opsErrorLogLastDropLogAt atomic.Int64 @@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() { } }() ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) - _ = job.ops.RecordError(ctx, job.entry, job.requestBody) + _ = job.ops.RecordError(ctx, job.entry, nil) cancel() opsErrorLogProcessed.Add(1) }() @@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() { } } -func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) { +func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) { if ops == nil || entry == nil { return } @@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo } select { - case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}: + case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}: opsErrorLogQueueLen.Add(1) opsErrorLogEnqueued.Add(1) default: @@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 { return opsErrorLogProcessed.Load() } +func OpsErrorLogSanitizedTotal() int64 { + return opsErrorLogSanitized.Load() +} + func maybeLogOpsErrorLogDrop() { now := time.Now().Unix() @@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() { queueCap := OpsErrorLogQueueCapacity() log.Printf( - "[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)", + "[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d sanitized_total=%d)", queued, queueCap, opsErrorLogEnqueued.Load(), opsErrorLogDropped.Load(), opsErrorLogProcessed.Load(), + opsErrorLogSanitized.Load(), ) } @@ -255,18 +260,49 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody if c == nil { return } + model = strings.TrimSpace(model) c.Set(opsModelKey, model) c.Set(opsStreamKey, stream) if len(requestBody) > 0 { c.Set(opsRequestBodyKey, requestBody) } + if c.Request != nil && model != "" { + ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model) + c.Request = c.Request.WithContext(ctx) + } } -func setOpsSelectedAccount(c *gin.Context, accountID int64) { +func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) { + if c == nil || entry == nil { + return + } + v, ok := c.Get(opsRequestBodyKey) + if !ok { + return + } + raw, ok := v.([]byte) + if !ok || len(raw) == 0 { + return + } + entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw) + opsErrorLogSanitized.Add(1) +} + +func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) { if c == nil || accountID <= 0 { return } c.Set(opsAccountIDKey, accountID) + if c.Request != nil { + ctx := context.WithValue(c.Request.Context(), ctxkey.AccountID, accountID) + if len(platform) > 0 { + p := strings.TrimSpace(platform[0]) + if p != "" { + ctx = context.WithValue(ctx, ctxkey.Platform, p) + } + } + c.Request = c.Request.WithContext(ctx) + } } type opsCaptureWriter struct { @@ -275,6 +311,35 @@ type opsCaptureWriter struct { buf bytes.Buffer } +const opsCaptureWriterLimit = 64 * 1024 + +var opsCaptureWriterPool = sync.Pool{ + New: func() any { + return &opsCaptureWriter{limit: opsCaptureWriterLimit} + }, +} + +func acquireOpsCaptureWriter(rw gin.ResponseWriter) *opsCaptureWriter { + w, ok := opsCaptureWriterPool.Get().(*opsCaptureWriter) + if !ok || w == nil { + w = &opsCaptureWriter{} + } + w.ResponseWriter = rw + w.limit = opsCaptureWriterLimit + w.buf.Reset() + return w +} + +func releaseOpsCaptureWriter(w *opsCaptureWriter) { + if w == nil { + return + } + w.ResponseWriter = nil + w.limit = opsCaptureWriterLimit + w.buf.Reset() + opsCaptureWriterPool.Put(w) +} + func (w *opsCaptureWriter) Write(b []byte) (int, error) { if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit { remaining := w.limit - w.buf.Len() @@ -306,7 +371,16 @@ func (w *opsCaptureWriter) WriteString(s string) (int, error) { // - Streaming errors after the response has started (SSE) may still need explicit logging. func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { return func(c *gin.Context) { - w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024} + originalWriter := c.Writer + w := acquireOpsCaptureWriter(originalWriter) + defer func() { + // Restore the original writer before returning so outer middlewares + // don't observe a pooled wrapper that has been released. + if c.Writer == w { + c.Writer = originalWriter + } + releaseOpsCaptureWriter(w) + }() c.Writer = w c.Next() @@ -507,6 +581,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { RetryCount: 0, CreatedAt: time.Now(), } + applyOpsLatencyFieldsFromContext(c, entry) if apiKey != nil { entry.APIKeyID = &apiKey.ID @@ -528,22 +603,31 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { entry.ClientIP = &clientIP } - var requestBody []byte - if v, ok := c.Get(opsRequestBodyKey); ok { - if b, ok := v.([]byte); ok && len(b) > 0 { - requestBody = b - } - } // Store request headers/body only when an upstream error occurred to keep overhead minimal. entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) + attachOpsRequestBodyToEntry(c, entry) - enqueueOpsErrorLog(ops, entry, requestBody) + // Skip logging if a passthrough rule with skip_monitoring=true matched. + if v, ok := c.Get(service.OpsSkipPassthroughKey); ok { + if skip, _ := v.(bool); skip { + return + } + } + + enqueueOpsErrorLog(ops, entry) return } body := w.buf.Bytes() parsed := parseOpsErrorResponse(body) + // Skip logging if a passthrough rule with skip_monitoring=true matched. + if v, ok := c.Get(service.OpsSkipPassthroughKey); ok { + if skip, _ := v.(bool); skip { + return + } + } + // Skip logging if the error should be filtered based on settings if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) { return @@ -578,8 +662,10 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { requestID = c.Writer.Header().Get("x-request-id") } - phase := classifyOpsPhase(parsed.ErrorType, parsed.Message, parsed.Code) - isBusinessLimited := classifyOpsIsBusinessLimited(parsed.ErrorType, phase, parsed.Code, status, parsed.Message) + normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code) + + phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code) + isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message) errorOwner := classifyOpsErrorOwner(phase, parsed.Message) errorSource := classifyOpsErrorSource(phase, parsed.Message) @@ -601,8 +687,8 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { UserAgent: c.GetHeader("User-Agent"), ErrorPhase: phase, - ErrorType: normalizeOpsErrorType(parsed.ErrorType, parsed.Code), - Severity: classifyOpsSeverity(parsed.ErrorType, status), + ErrorType: normalizedType, + Severity: classifyOpsSeverity(normalizedType, status), StatusCode: status, IsBusinessLimited: isBusinessLimited, IsCountTokens: isCountTokensRequest(c), @@ -614,10 +700,11 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ErrorSource: errorSource, ErrorOwner: errorOwner, - IsRetryable: classifyOpsIsRetryable(parsed.ErrorType, status), + IsRetryable: classifyOpsIsRetryable(normalizedType, status), RetryCount: 0, CreatedAt: time.Now(), } + applyOpsLatencyFieldsFromContext(c, entry) // Capture upstream error context set by gateway services (if present). // This does NOT affect the client response; it enriches Ops troubleshooting data. @@ -693,17 +780,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { entry.ClientIP = &clientIP } - var requestBody []byte - if v, ok := c.Get(opsRequestBodyKey); ok { - if b, ok := v.([]byte); ok && len(b) > 0 { - requestBody = b - } - } // Persist only a minimal, whitelisted set of request headers to improve retry fidelity. // Do NOT store Authorization/Cookie/etc. entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) + attachOpsRequestBodyToEntry(c, entry) - enqueueOpsErrorLog(ops, entry, requestBody) + enqueueOpsErrorLog(ops, entry) } } @@ -746,6 +828,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string { return &s } +func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) { + if c == nil || entry == nil { + return + } + entry.AuthLatencyMs = getContextLatencyMs(c, service.OpsAuthLatencyMsKey) + entry.RoutingLatencyMs = getContextLatencyMs(c, service.OpsRoutingLatencyMsKey) + entry.UpstreamLatencyMs = getContextLatencyMs(c, service.OpsUpstreamLatencyMsKey) + entry.ResponseLatencyMs = getContextLatencyMs(c, service.OpsResponseLatencyMsKey) + entry.TimeToFirstTokenMs = getContextLatencyMs(c, service.OpsTimeToFirstTokenMsKey) +} + +func getContextLatencyMs(c *gin.Context, key string) *int64 { + if c == nil || strings.TrimSpace(key) == "" { + return nil + } + v, ok := c.Get(key) + if !ok { + return nil + } + var ms int64 + switch t := v.(type) { + case int: + ms = int64(t) + case int32: + ms = int64(t) + case int64: + ms = t + case float64: + ms = int64(t) + default: + return nil + } + if ms < 0 { + return nil + } + return &ms +} + type parsedOpsError struct { ErrorType string Message string @@ -821,8 +941,29 @@ func guessPlatformFromPath(path string) string { } } +// isKnownOpsErrorType returns true if t is a recognized error type used by the +// ops classification pipeline. Upstream proxies sometimes return garbage values +// (e.g. the Go-serialized literal "") which would pollute phase/severity +// classification if accepted blindly. +func isKnownOpsErrorType(t string) bool { + switch t { + case "invalid_request_error", + "authentication_error", + "rate_limit_error", + "billing_error", + "subscription_error", + "upstream_error", + "overloaded_error", + "api_error", + "not_found_error", + "forbidden_error": + return true + } + return false +} + func normalizeOpsErrorType(errType string, code string) string { - if errType != "" { + if errType != "" && isKnownOpsErrorType(errType) { return errType } switch strings.TrimSpace(code) { diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go new file mode 100644 index 00000000..679dd4ce --- /dev/null +++ b/backend/internal/handler/ops_error_logger_test.go @@ -0,0 +1,276 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func resetOpsErrorLoggerStateForTest(t *testing.T) { + t.Helper() + + opsErrorLogMu.Lock() + ch := opsErrorLogQueue + opsErrorLogQueue = nil + opsErrorLogStopping = true + opsErrorLogMu.Unlock() + + if ch != nil { + close(ch) + } + opsErrorLogWorkersWg.Wait() + + opsErrorLogOnce = sync.Once{} + opsErrorLogStopOnce = sync.Once{} + opsErrorLogWorkersWg = sync.WaitGroup{} + opsErrorLogMu = sync.RWMutex{} + opsErrorLogStopping = false + + opsErrorLogQueueLen.Store(0) + opsErrorLogEnqueued.Store(0) + opsErrorLogDropped.Store(0) + opsErrorLogProcessed.Store(0) + opsErrorLogSanitized.Store(0) + opsErrorLogLastDropLogAt.Store(0) + + opsErrorLogShutdownCh = make(chan struct{}) + opsErrorLogShutdownOnce = sync.Once{} + opsErrorLogDrained.Store(false) +} + +func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`) + setOpsRequestContext(c, "claude-3", false, raw) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(c, entry) + + require.NotNil(t, entry.RequestBodyBytes) + require.Equal(t, len(raw), *entry.RequestBodyBytes) + require.NotNil(t, entry.RequestBodyJSON) + require.NotContains(t, *entry.RequestBodyJSON, "secret-token") + require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]") + require.Equal(t, int64(1), OpsErrorLogSanitizedTotal()) +} + +func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + raw := []byte("not-json") + setOpsRequestContext(c, "claude-3", false, raw) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(c, entry) + + require.Nil(t, entry.RequestBodyJSON) + require.NotNil(t, entry.RequestBodyBytes) + require.Equal(t, len(raw), *entry.RequestBodyBytes) + require.False(t, entry.RequestBodyTruncated) + require.Equal(t, int64(1), OpsErrorLogSanitizedTotal()) +} + +func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + + // 禁止 enqueueOpsErrorLog 触发 workers,使用测试队列验证满队列降级。 + opsErrorLogOnce.Do(func() {}) + + opsErrorLogMu.Lock() + opsErrorLogQueue = make(chan opsErrorLogJob, 1) + opsErrorLogMu.Unlock() + + ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"} + + enqueueOpsErrorLog(ops, entry) + enqueueOpsErrorLog(ops, entry) + + require.Equal(t, int64(1), OpsErrorLogEnqueuedTotal()) + require.Equal(t, int64(1), OpsErrorLogDroppedTotal()) + require.Equal(t, int64(1), OpsErrorLogQueueLength()) +} + +func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(nil, entry) + attachOpsRequestBodyToEntry(&gin.Context{}, nil) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // 无请求体 key + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + require.False(t, entry.RequestBodyTruncated) + + // 错误类型 + c.Set(opsRequestBodyKey, "not-bytes") + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + + // 空 bytes + c.Set(opsRequestBodyKey, []byte{}) + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + + require.Equal(t, int64(0), OpsErrorLogSanitizedTotal()) +} + +func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + + ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"} + + // nil 入参分支 + enqueueOpsErrorLog(nil, entry) + enqueueOpsErrorLog(ops, nil) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // shutdown 分支 + close(opsErrorLogShutdownCh) + enqueueOpsErrorLog(ops, entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // stopping 分支 + resetOpsErrorLoggerStateForTest(t) + opsErrorLogMu.Lock() + opsErrorLogStopping = true + opsErrorLogMu.Unlock() + enqueueOpsErrorLog(ops, entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // queue nil 分支(防止启动 worker 干扰) + resetOpsErrorLoggerStateForTest(t) + opsErrorLogOnce.Do(func() {}) + opsErrorLogMu.Lock() + opsErrorLogQueue = nil + opsErrorLogMu.Unlock() + enqueueOpsErrorLog(ops, entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) +} + +func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/test", nil) + + writer := acquireOpsCaptureWriter(c.Writer) + require.NotNil(t, writer) + _, err := writer.buf.WriteString("temp-error-body") + require.NoError(t, err) + + releaseOpsCaptureWriter(writer) + + reused := acquireOpsCaptureWriter(c.Writer) + defer releaseOpsCaptureWriter(reused) + + require.Zero(t, reused.buf.Len(), "writer should be reset before reuse") +} + +func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(middleware2.Recovery()) + r.Use(middleware2.RequestLogger()) + r.Use(middleware2.Logger()) + r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) { + c.Status(http.StatusNoContent) + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil) + + require.NotPanics(t, func() { + r.ServeHTTP(rec, req) + }) + require.Equal(t, http.StatusNoContent, rec.Code) +} + +func TestIsKnownOpsErrorType(t *testing.T) { + known := []string{ + "invalid_request_error", + "authentication_error", + "rate_limit_error", + "billing_error", + "subscription_error", + "upstream_error", + "overloaded_error", + "api_error", + "not_found_error", + "forbidden_error", + } + for _, k := range known { + require.True(t, isKnownOpsErrorType(k), "expected known: %s", k) + } + + unknown := []string{"", "null", "", "random_error", "some_new_type", "\u003e"} + for _, u := range unknown { + require.False(t, isKnownOpsErrorType(u), "expected unknown: %q", u) + } +} + +func TestNormalizeOpsErrorType(t *testing.T) { + tests := []struct { + name string + errType string + code string + want string + }{ + // Known types pass through. + {"known invalid_request_error", "invalid_request_error", "", "invalid_request_error"}, + {"known rate_limit_error", "rate_limit_error", "", "rate_limit_error"}, + {"known upstream_error", "upstream_error", "", "upstream_error"}, + + // Unknown/garbage types are rejected and fall through to code-based or default. + {"nil literal from upstream", "", "", "api_error"}, + {"null string", "null", "", "api_error"}, + {"random string", "something_weird", "", "api_error"}, + + // Unknown type but known code still maps correctly. + {"nil with INSUFFICIENT_BALANCE code", "", "INSUFFICIENT_BALANCE", "billing_error"}, + {"nil with USAGE_LIMIT_EXCEEDED code", "", "USAGE_LIMIT_EXCEEDED", "subscription_error"}, + + // Empty type falls through to code-based mapping. + {"empty type with balance code", "", "INSUFFICIENT_BALANCE", "billing_error"}, + {"empty type with subscription code", "", "SUBSCRIPTION_NOT_FOUND", "subscription_error"}, + {"empty type no code", "", "", "api_error"}, + + // Known type overrides conflicting code-based mapping. + {"known type overrides conflicting code", "rate_limit_error", "INSUFFICIENT_BALANCE", "rate_limit_error"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeOpsErrorType(tt.errType, tt.code) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 2029f116..a48eaf31 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -50,7 +50,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + SoraClientEnabled: settings.SoraClientEnabled, Version: h.version, }) } diff --git a/backend/internal/handler/sora_client_handler.go b/backend/internal/handler/sora_client_handler.go new file mode 100644 index 00000000..80acc833 --- /dev/null +++ b/backend/internal/handler/sora_client_handler.go @@ -0,0 +1,979 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +const ( + // 上游模型缓存 TTL + modelCacheTTL = 1 * time.Hour // 上游获取成功 + modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地) +) + +// SoraClientHandler 处理 Sora 客户端 API 请求。 +type SoraClientHandler struct { + genService *service.SoraGenerationService + quotaService *service.SoraQuotaService + s3Storage *service.SoraS3Storage + soraGatewayService *service.SoraGatewayService + gatewayService *service.GatewayService + mediaStorage *service.SoraMediaStorage + apiKeyService *service.APIKeyService + + // 上游模型缓存 + modelCacheMu sync.RWMutex + cachedFamilies []service.SoraModelFamily + modelCacheTime time.Time + modelCacheUpstream bool // 是否来自上游(决定 TTL) +} + +// NewSoraClientHandler 创建 Sora 客户端 Handler。 +func NewSoraClientHandler( + genService *service.SoraGenerationService, + quotaService *service.SoraQuotaService, + s3Storage *service.SoraS3Storage, + soraGatewayService *service.SoraGatewayService, + gatewayService *service.GatewayService, + mediaStorage *service.SoraMediaStorage, + apiKeyService *service.APIKeyService, +) *SoraClientHandler { + return &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + s3Storage: s3Storage, + soraGatewayService: soraGatewayService, + gatewayService: gatewayService, + mediaStorage: mediaStorage, + apiKeyService: apiKeyService, + } +} + +// GenerateRequest 生成请求。 +type GenerateRequest struct { + Model string `json:"model" binding:"required"` + Prompt string `json:"prompt" binding:"required"` + MediaType string `json:"media_type"` // video / image,默认 video + VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3) + ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL) + APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID +} + +// Generate 异步生成 — 创建 pending 记录后立即返回。 +// POST /api/v1/sora/generate +func (h *SoraClientHandler) Generate(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + var req GenerateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error()) + return + } + + if req.MediaType == "" { + req.MediaType = "video" + } + req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount) + + // 并发数检查(最多 3 个) + activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if activeCount >= 3 { + response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") + return + } + + // 配额检查(粗略检查,实际文件大小在上传后才知道) + if h.quotaService != nil { + if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil { + var quotaErr *service.QuotaExceededError + if errors.As(err, "aErr) { + response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") + return + } + response.Error(c, http.StatusForbidden, err.Error()) + return + } + } + + // 获取 API Key ID 和 Group ID + var apiKeyID *int64 + var groupID *int64 + + if req.APIKeyID != nil && h.apiKeyService != nil { + // 前端传递了 api_key_id,需要校验 + apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID) + if err != nil { + response.Error(c, http.StatusBadRequest, "API Key 不存在") + return + } + if apiKey.UserID != userID { + response.Error(c, http.StatusForbidden, "API Key 不属于当前用户") + return + } + if apiKey.Status != service.StatusAPIKeyActive { + response.Error(c, http.StatusForbidden, "API Key 不可用") + return + } + apiKeyID = &apiKey.ID + groupID = apiKey.GroupID + } else if id, ok := c.Get("api_key_id"); ok { + // 兼容 API Key 认证路径(/sora/v1/ 网关路由) + if v, ok := id.(int64); ok { + apiKeyID = &v + } + } + + gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType) + if err != nil { + if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) { + response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") + return + } + response.ErrorFrom(c, err) + return + } + + // 启动后台异步生成 goroutine + go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount) + + response.Success(c, gin.H{ + "generation_id": gen.ID, + "status": gen.Status, + }) +} + +// processGeneration 后台异步执行 Sora 生成任务。 +// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。 +func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + // 标记为生成中 + if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil { + if errors.Is(err, service.ErrSoraGenerationStateConflict) { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID) + return + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err) + return + } + + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d", + genID, + userID, + groupIDForLog(groupID), + model, + mediaType, + videoCount, + strings.TrimSpace(imageInput) != "", + len(strings.TrimSpace(prompt)), + ) + + // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底 + if groupID == nil { + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) + } + + if h.gatewayService == nil { + _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化") + return + } + + // 选择 Sora 账号 + account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model) + if err != nil { + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v", + genID, + userID, + groupIDForLog(groupID), + model, + err, + ) + _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error()) + return + } + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s", + genID, + userID, + groupIDForLog(groupID), + model, + account.ID, + account.Name, + account.Platform, + account.Type, + ) + + // 构建 chat completions 请求体(非流式) + body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount)) + + if h.soraGatewayService == nil { + _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化") + return + } + + // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL) + recorder := httptest.NewRecorder() + mockGinCtx, _ := gin.CreateTestContext(recorder) + mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil) + + // 调用 Forward(非流式) + result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false) + if err != nil { + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v", + genID, + account.ID, + model, + recorder.Code, + trimForLog(recorder.Body.String(), 400), + err, + ) + // 检查是否已取消 + gen, _ := h.genService.GetByID(ctx, genID, userID) + if gen != nil && gen.Status == service.SoraGenStatusCancelled { + return + } + _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error()) + return + } + + // 提取媒体 URL(优先从 ForwardResult,其次从响应体解析) + mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder) + if mediaURL == "" { + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s", + genID, + account.ID, + model, + recorder.Code, + trimForLog(recorder.Body.String(), 400), + ) + _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL") + return + } + + // 检查任务是否已被取消 + gen, _ := h.genService.GetByID(ctx, genID, userID) + if gen != nil && gen.Status == service.SoraGenStatusCancelled { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID) + return + } + + // 三层降级存储:S3 → 本地 → 上游临时 URL + storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs) + + usageAdded := false + if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil { + if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil { + h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) + var quotaErr *service.QuotaExceededError + if errors.As(err, "aErr) { + _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间") + return + } + _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error()) + return + } + usageAdded = true + } + + // 存储完成后再做一次取消检查,防止取消被 completed 覆盖。 + gen, _ = h.genService.GetByID(ctx, genID, userID) + if gen != nil && gen.Status == service.SoraGenStatusCancelled { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID) + h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) + } + return + } + + // 标记完成 + if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil { + if errors.Is(err, service.ErrSoraGenerationStateConflict) { + h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) + } + return + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err) + return + } + + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize) +} + +// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。 +func (h *SoraClientHandler) storeMediaWithDegradation( + ctx context.Context, userID int64, mediaType string, + mediaURL string, mediaURLs []string, +) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) { + urls := mediaURLs + if len(urls) == 0 { + urls = []string{mediaURL} + } + + // 第一层:尝试 S3 + if h.s3Storage != nil && h.s3Storage.Enabled(ctx) { + keys := make([]string, 0, len(urls)) + var totalSize int64 + allOK := true + for _, u := range urls { + key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u) + if err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err) + allOK = false + // 清理已上传的文件 + if len(keys) > 0 { + _ = h.s3Storage.DeleteObjects(ctx, keys) + } + break + } + keys = append(keys, key) + totalSize += size + } + if allOK && len(keys) > 0 { + accessURLs := make([]string, 0, len(keys)) + for _, key := range keys { + accessURL, err := h.s3Storage.GetAccessURL(ctx, key) + if err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err) + _ = h.s3Storage.DeleteObjects(ctx, keys) + allOK = false + break + } + accessURLs = append(accessURLs, accessURL) + } + if allOK && len(accessURLs) > 0 { + return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize + } + } + } + + // 第二层:尝试本地存储 + if h.mediaStorage != nil && h.mediaStorage.Enabled() { + storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls) + if err == nil && len(storedPaths) > 0 { + firstPath := storedPaths[0] + totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths) + if sizeErr != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr) + } + return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err) + } + + // 第三层:保留上游临时 URL + return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0 +} + +// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。 +func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte { + body := map[string]any{ + "model": model, + "messages": []map[string]string{ + {"role": "user", "content": prompt}, + }, + "stream": false, + } + if imageInput != "" { + body["image_input"] = imageInput + } + if videoCount > 1 { + body["video_count"] = videoCount + } + b, _ := json.Marshal(body) + return b +} + +func normalizeVideoCount(mediaType string, videoCount int) int { + if mediaType != "video" { + return 1 + } + if videoCount <= 0 { + return 1 + } + if videoCount > 3 { + return 3 + } + return videoCount +} + +// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。 +// OAuth 路径:ForwardResult.MediaURL 已填充。 +// APIKey 路径:需从响应体解析 media_url / media_urls 字段。 +func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) { + // 优先从 ForwardResult 获取(OAuth 路径) + if result != nil && result.MediaURL != "" { + // 尝试从响应体获取完整 URL 列表 + if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { + return urls[0], urls + } + return result.MediaURL, []string{result.MediaURL} + } + + // 从响应体解析(APIKey 路径) + if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { + return urls[0], urls + } + + return "", nil +} + +// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。 +func parseMediaURLsFromBody(body []byte) []string { + if len(body) == 0 { + return nil + } + var resp map[string]any + if err := json.Unmarshal(body, &resp); err != nil { + return nil + } + + // 优先 media_urls(多图数组) + if rawURLs, ok := resp["media_urls"]; ok { + if arr, ok := rawURLs.([]any); ok && len(arr) > 0 { + urls := make([]string, 0, len(arr)) + for _, item := range arr { + if s, ok := item.(string); ok && s != "" { + urls = append(urls, s) + } + } + if len(urls) > 0 { + return urls + } + } + } + + // 回退到 media_url(单个 URL) + if url, ok := resp["media_url"].(string); ok && url != "" { + return []string{url} + } + + return nil +} + +// ListGenerations 查询生成记录列表。 +// GET /api/v1/sora/generations +func (h *SoraClientHandler) ListGenerations(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + params := service.SoraGenerationListParams{ + UserID: userID, + Status: c.Query("status"), + StorageType: c.Query("storage_type"), + MediaType: c.Query("media_type"), + Page: page, + PageSize: pageSize, + } + + gens, total, err := h.genService.List(c.Request.Context(), params) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // 为 S3 记录动态生成预签名 URL + for _, gen := range gens { + _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) + } + + response.Success(c, gin.H{ + "data": gens, + "total": total, + "page": page, + }) +} + +// GetGeneration 查询生成记录详情。 +// GET /api/v1/sora/generations/:id +func (h *SoraClientHandler) GetGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) + response.Success(c, gen) +} + +// DeleteGeneration 删除生成记录。 +// DELETE /api/v1/sora/generations/:id +func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。 + if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil { + paths := gen.MediaURLs + if len(paths) == 0 && gen.MediaURL != "" { + paths = []string{gen.MediaURL} + } + if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err) + } + } + + if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + response.Success(c, gin.H{"message": "已删除"}) +} + +// GetQuota 查询用户存储配额。 +// GET /api/v1/sora/quota +func (h *SoraClientHandler) GetQuota(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + if h.quotaService == nil { + response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"}) + return + } + + quota, err := h.quotaService.GetQuota(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, quota) +} + +// CancelGeneration 取消生成任务。 +// POST /api/v1/sora/generations/:id/cancel +func (h *SoraClientHandler) CancelGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + // 权限校验 + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + _ = gen + + if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil { + if errors.Is(err, service.ErrSoraGenerationNotActive) { + response.Error(c, http.StatusConflict, "任务已结束,无法取消") + return + } + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + + response.Success(c, gin.H{"message": "已取消"}) +} + +// SaveToStorage 手动保存 upstream 记录到 S3。 +// POST /api/v1/sora/generations/:id/save +func (h *SoraClientHandler) SaveToStorage(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + if gen.StorageType != service.SoraStorageTypeUpstream { + response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存") + return + } + if gen.MediaURL == "" { + response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") + return + } + + if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) { + response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员") + return + } + + sourceURLs := gen.MediaURLs + if len(sourceURLs) == 0 && gen.MediaURL != "" { + sourceURLs = []string{gen.MediaURL} + } + if len(sourceURLs) == 0 { + response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") + return + } + + uploadedKeys := make([]string, 0, len(sourceURLs)) + accessURLs := make([]string, 0, len(sourceURLs)) + var totalSize int64 + + for _, sourceURL := range sourceURLs { + objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL) + if uploadErr != nil { + if len(uploadedKeys) > 0 { + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + } + var upstreamErr *service.UpstreamDownloadError + if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) { + response.Error(c, http.StatusGone, "媒体链接已过期,无法保存") + return + } + response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error()) + return + } + accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey) + if err != nil { + uploadedKeys = append(uploadedKeys, objectKey) + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error()) + return + } + uploadedKeys = append(uploadedKeys, objectKey) + accessURLs = append(accessURLs, accessURL) + totalSize += fileSize + } + + usageAdded := false + if totalSize > 0 && h.quotaService != nil { + if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil { + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + var quotaErr *service.QuotaExceededError + if errors.As(err, "aErr) { + response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") + return + } + response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error()) + return + } + usageAdded = true + } + + if err := h.genService.UpdateStorageForCompleted( + c.Request.Context(), + id, + accessURLs[0], + accessURLs, + service.SoraStorageTypeS3, + uploadedKeys, + totalSize, + ); err != nil { + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize) + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "message": "已保存到 S3", + "object_key": uploadedKeys[0], + "object_keys": uploadedKeys, + }) +} + +// GetStorageStatus 返回存储状态。 +// GET /api/v1/sora/storage-status +func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) { + s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context()) + s3Healthy := false + if s3Enabled { + s3Healthy = h.s3Storage.IsHealthy(c.Request.Context()) + } + localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled() + response.Success(c, gin.H{ + "s3_enabled": s3Enabled, + "s3_healthy": s3Healthy, + "local_enabled": localEnabled, + }) +} + +func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) { + switch storageType { + case service.SoraStorageTypeS3: + if h.s3Storage != nil && len(s3Keys) > 0 { + if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err) + } + } + case service.SoraStorageTypeLocal: + if h.mediaStorage != nil && len(localPaths) > 0 { + if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err) + } + } + } +} + +// getUserIDFromContext 从 gin 上下文中提取用户 ID。 +func getUserIDFromContext(c *gin.Context) int64 { + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { + return subject.UserID + } + + if id, ok := c.Get("user_id"); ok { + switch v := id.(type) { + case int64: + return v + case float64: + return int64(v) + case string: + n, _ := strconv.ParseInt(v, 10, 64) + return n + } + } + // 尝试从 JWT claims 获取 + if id, ok := c.Get("userID"); ok { + if v, ok := id.(int64); ok { + return v + } + } + return 0 +} + +func groupIDForLog(groupID *int64) int64 { + if groupID == nil { + return 0 + } + return *groupID +} + +func trimForLog(raw string, maxLen int) string { + trimmed := strings.TrimSpace(raw) + if maxLen <= 0 || len(trimmed) <= maxLen { + return trimmed + } + return trimmed[:maxLen] + "...(truncated)" +} + +// GetModels 获取可用 Sora 模型家族列表。 +// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。 +// GET /api/v1/sora/models +func (h *SoraClientHandler) GetModels(c *gin.Context) { + families := h.getModelFamilies(c.Request.Context()) + response.Success(c, families) +} + +// getModelFamilies 获取模型家族列表(带缓存)。 +func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily { + // 读锁检查缓存 + h.modelCacheMu.RLock() + ttl := modelCacheTTL + if !h.modelCacheUpstream { + ttl = modelCacheFailedTTL + } + if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { + families := h.cachedFamilies + h.modelCacheMu.RUnlock() + return families + } + h.modelCacheMu.RUnlock() + + // 写锁更新缓存 + h.modelCacheMu.Lock() + defer h.modelCacheMu.Unlock() + + // double-check + ttl = modelCacheTTL + if !h.modelCacheUpstream { + ttl = modelCacheFailedTTL + } + if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { + return h.cachedFamilies + } + + // 尝试从上游获取 + families, err := h.fetchUpstreamModels(ctx) + if err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err) + families = service.BuildSoraModelFamilies() + h.cachedFamilies = families + h.modelCacheTime = time.Now() + h.modelCacheUpstream = false + return families + } + + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families)) + h.cachedFamilies = families + h.modelCacheTime = time.Now() + h.modelCacheUpstream = true + return families +} + +// fetchUpstreamModels 从上游 Sora API 获取模型列表。 +func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) { + if h.gatewayService == nil { + return nil, fmt.Errorf("gatewayService 未初始化") + } + + // 设置 ForcePlatform 用于 Sora 账号选择 + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) + + // 选择一个 Sora 账号 + account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s") + if err != nil { + return nil, fmt.Errorf("选择 Sora 账号失败: %w", err) + } + + // 仅支持 API Key 类型账号 + if account.Type != service.AccountTypeAPIKey { + return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type) + } + + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return nil, fmt.Errorf("账号缺少 api_key") + } + + baseURL := account.GetBaseURL() + if baseURL == "" { + return nil, fmt.Errorf("账号缺少 base_url") + } + + // 构建上游模型列表请求 + modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models" + + reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + req.Header.Set("Authorization", "Bearer "+apiKey) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("请求上游失败: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 解析 OpenAI 格式的模型列表 + var modelsResp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + if err := json.Unmarshal(body, &modelsResp); err != nil { + return nil, fmt.Errorf("解析响应失败: %w", err) + } + + if len(modelsResp.Data) == 0 { + return nil, fmt.Errorf("上游返回空模型列表") + } + + // 提取模型 ID + modelIDs := make([]string, 0, len(modelsResp.Data)) + for _, m := range modelsResp.Data { + modelIDs = append(modelIDs, m.ID) + } + + // 转换为模型家族 + families := service.BuildSoraModelFamiliesFromIDs(modelIDs) + if len(families) == 0 { + return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族") + } + + return families, nil +} diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go new file mode 100644 index 00000000..5df7fa0a --- /dev/null +++ b/backend/internal/handler/sora_client_handler_test.go @@ -0,0 +1,3138 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// ==================== Stub: SoraGenerationRepository ==================== + +var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil) + +type stubSoraGenRepo struct { + gens map[int64]*service.SoraGeneration + nextID int64 + createErr error + getErr error + updateErr error + deleteErr error + listErr error + countErr error + countValue int64 + + // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败 + updateCallCount *int32 + updateFailAfterN int32 + + // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus + getByIDCallCount int32 + getByIDOverrideAfterN int32 // 0 = 不覆盖 + getByIDOverrideStatus string +} + +func newStubSoraGenRepo() *stubSoraGenRepo { + return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1} +} + +func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error { + if r.createErr != nil { + return r.createErr + } + gen.ID = r.nextID + r.nextID++ + r.gens[gen.ID] = gen + return nil +} +func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) { + if r.getErr != nil { + return nil, r.getErr + } + gen, ok := r.gens[id] + if !ok { + return nil, fmt.Errorf("not found") + } + // 条件性状态覆盖:模拟外部取消等场景 + if r.getByIDOverrideAfterN > 0 { + n := atomic.AddInt32(&r.getByIDCallCount, 1) + if n > r.getByIDOverrideAfterN { + cp := *gen + cp.Status = r.getByIDOverrideStatus + return &cp, nil + } + } + return gen, nil +} +func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error { + // 条件性失败:前 N 次成功,之后失败 + if r.updateCallCount != nil { + n := atomic.AddInt32(r.updateCallCount, 1) + if n > r.updateFailAfterN { + return fmt.Errorf("conditional update error (call #%d)", n) + } + } + if r.updateErr != nil { + return r.updateErr + } + r.gens[gen.ID] = gen + return nil +} +func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error { + if r.deleteErr != nil { + return r.deleteErr + } + delete(r.gens, id) + return nil +} +func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { + if r.listErr != nil { + return nil, 0, r.listErr + } + var result []*service.SoraGeneration + for _, gen := range r.gens { + if gen.UserID != params.UserID { + continue + } + result = append(result, gen) + } + return result, int64(len(result)), nil +} +func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) { + if r.countErr != nil { + return 0, r.countErr + } + return r.countValue, nil +} + +// ==================== 辅助函数 ==================== + +func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler { + genService := service.NewSoraGenerationService(repo, nil, nil) + return &SoraClientHandler{genService: genService} +} + +func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + if body != "" { + c.Request = httptest.NewRequest(method, path, strings.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + } else { + c.Request = httptest.NewRequest(method, path, nil) + } + if userID > 0 { + c.Set("user_id", userID) + } + return c, rec +} + +func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any { + t.Helper() + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + return resp +} + +// ==================== 纯函数测试: buildAsyncRequestBody ==================== + +func TestBuildAsyncRequestBody(t *testing.T) { + body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, "sora2-landscape-10s", parsed["model"]) + require.Equal(t, false, parsed["stream"]) + + msgs := parsed["messages"].([]any) + require.Len(t, msgs, 1) + msg := msgs[0].(map[string]any) + require.Equal(t, "user", msg["role"]) + require.Equal(t, "一只猫在跳舞", msg["content"]) +} + +func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) { + body := buildAsyncRequestBody("gpt-image", "", "", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, "gpt-image", parsed["model"]) + msgs := parsed["messages"].([]any) + msg := msgs[0].(map[string]any) + require.Equal(t, "", msg["content"]) +} + +func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) { + body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, "https://example.com/ref.png", parsed["image_input"]) +} + +func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) { + body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, float64(3), parsed["video_count"]) +} + +func TestNormalizeVideoCount(t *testing.T) { + require.Equal(t, 1, normalizeVideoCount("video", 0)) + require.Equal(t, 2, normalizeVideoCount("video", 2)) + require.Equal(t, 3, normalizeVideoCount("video", 5)) + require.Equal(t, 1, normalizeVideoCount("image", 3)) +} + +// ==================== 纯函数测试: parseMediaURLsFromBody ==================== + +func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`)) + require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`)) + require.Equal(t, []string{"https://a.com/video.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody(nil)) + require.Nil(t, parseMediaURLsFromBody([]byte{})) +} + +func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte("not json"))) +} + +func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`))) +} + +func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`))) +} + +func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`))) +} + +func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) { + body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}` + urls := parseMediaURLsFromBody([]byte(body)) + require.Len(t, urls, 2) + require.Equal(t, "https://multi.com/a.mp4", urls[0]) +} + +func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`)) + require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`))) +} + +func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) { + // media_urls 不是 string 数组 + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`))) +} + +func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`))) +} + +// ==================== 纯函数测试: extractMediaURLsFromResult ==================== + +func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) { + result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(result, recorder) + require.Equal(t, "https://oauth.com/video.mp4", url) + require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls) +} + +func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) { + result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} + recorder := httptest.NewRecorder() + _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`)) + url, urls := extractMediaURLsFromResult(result, recorder) + require.Equal(t, "https://body.com/1.mp4", url) + require.Len(t, urls, 2) +} + +func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) { + recorder := httptest.NewRecorder() + _, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`)) + url, urls := extractMediaURLsFromResult(nil, recorder) + require.Equal(t, "https://upstream.com/video.mp4", url) + require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls) +} + +func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) { + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(nil, recorder) + require.Empty(t, url) + require.Nil(t, urls) +} + +func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) { + result := &service.ForwardResult{MediaURL: ""} + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(result, recorder) + require.Empty(t, url) + require.Nil(t, urls) +} + +// ==================== getUserIDFromContext ==================== + +func TestGetUserIDFromContext_Int64(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", int64(42)) + require.Equal(t, int64(42), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_AuthSubject(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777}) + require.Equal(t, int64(777), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_Float64(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", float64(99)) + require.Equal(t, int64(99), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_String(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", "123") + require.Equal(t, int64(123), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_UserIDFallback(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("userID", int64(55)) + require.Equal(t, int64(55), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_NoID(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + require.Equal(t, int64(0), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_InvalidString(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", "not-a-number") + require.Equal(t, int64(0), getUserIDFromContext(c)) +} + +// ==================== Handler: Generate ==================== + +func TestGenerate_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0) + h.Generate(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGenerate_BadRequest_MissingModel(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_BadRequest_MissingPrompt(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_BadRequest_InvalidJSON(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_TooManyRequests(t *testing.T) { + repo := newStubSoraGenRepo() + repo.countValue = 3 + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +func TestGenerate_CountError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.countErr = fmt.Errorf("db error") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestGenerate_Success(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.NotZero(t, data["generation_id"]) + require.Equal(t, "pending", data["status"]) +} + +func TestGenerate_DefaultMediaType(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "video", repo.gens[1].MediaType) +} + +func TestGenerate_ImageMediaType(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "image", repo.gens[1].MediaType) +} + +func TestGenerate_CreatePendingError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.createErr = fmt.Errorf("create failed") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestGenerate_APIKeyInContext(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + c.Set("api_key_id", int64(42)) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_NoAPIKeyInContext(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_ConcurrencyBoundary(t *testing.T) { + // activeCount == 2 应该允许 + repo := newStubSoraGenRepo() + repo.countValue = 2 + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== Handler: ListGenerations ==================== + +func TestListGenerations_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0) + h.ListGenerations(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestListGenerations_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"} + repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"} + repo.nextID = 3 + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + items := data["data"].([]any) + require.Len(t, items, 2) + require.Equal(t, float64(2), data["total"]) +} + +func TestListGenerations_ListError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.listErr = fmt.Errorf("db error") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestListGenerations_DefaultPagination(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + // 不传分页参数,应默认 page=1 page_size=20 + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, float64(1), data["page"]) +} + +// ==================== Handler: GetGeneration ==================== + +func TestGetGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGetGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.GetGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGetGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.GetGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestGetGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestGetGeneration_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, float64(1), data["id"]) +} + +// ==================== Handler: DeleteGeneration ==================== + +func TestDeleteGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestDeleteGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDeleteGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestDeleteGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestDeleteGeneration_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + _, exists := repo.gens[1] + require.False(t, exists) +} + +// ==================== Handler: CancelGeneration ==================== + +func TestCancelGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestCancelGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestCancelGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestCancelGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestCancelGeneration_Pending(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "cancelled", repo.gens[1].Status) +} + +func TestCancelGeneration_Generating(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "cancelled", repo.gens[1].Status) +} + +func TestCancelGeneration_Completed(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +func TestCancelGeneration_Failed(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +func TestCancelGeneration_Cancelled(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +// ==================== Handler: GetQuota ==================== + +func TestGetQuota_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0) + h.GetQuota(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGetQuota_NilQuotaService(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) + h.GetQuota(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, "unlimited", data["source"]) +} + +// ==================== Handler: GetModels ==================== + +func TestGetModels(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0) + h.GetModels(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].([]any) + require.Len(t, data, 4) + // 验证类型分布 + videoCount, imageCount := 0, 0 + for _, item := range data { + m := item.(map[string]any) + if m["type"] == "video" { + videoCount++ + } else if m["type"] == "image" { + imageCount++ + } + } + require.Equal(t, 3, videoCount) + require.Equal(t, 1, imageCount) +} + +// ==================== Handler: GetStorageStatus ==================== + +func TestGetStorageStatus_NilS3(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, false, data["s3_enabled"]) + require.Equal(t, false, data["s3_healthy"]) + require.Equal(t, false, data["local_enabled"]) +} + +func TestGetStorageStatus_LocalEnabled(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-storage-status-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, false, data["s3_enabled"]) + require.Equal(t, false, data["s3_healthy"]) + require.Equal(t, true, data["local_enabled"]) +} + +// ==================== Handler: SaveToStorage ==================== + +func TestSaveToStorage_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestSaveToStorage_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestSaveToStorage_NotUpstream(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_EmptyMediaURL(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_S3Nil(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "云存储") +} + +func TestSaveToStorage_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +// ==================== storeMediaWithDegradation — nil guard 路径 ==================== + +func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) { + h := &SoraClientHandler{} + url, urls, storageType, keys, size := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Equal(t, "https://upstream.com/v.mp4", url) + require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls) + require.Nil(t, keys) + require.Equal(t, int64(0), size) +} + +func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) { + h := &SoraClientHandler{} + url, urls, storageType, keys, size := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Equal(t, "https://a.com/1.mp4", url) + require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) + require.Nil(t, keys) + require.Equal(t, int64(0), size) +} + +func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) { + h := &SoraClientHandler{} + url, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{}, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Equal(t, "https://upstream.com/v.mp4", url) +} + +// ==================== Stub: UserRepository (用于 SoraQuotaService) ==================== + +var _ service.UserRepository = (*stubUserRepoForHandler)(nil) + +type stubUserRepoForHandler struct { + users map[int64]*service.User + updateErr error +} + +func newStubUserRepoForHandler() *stubUserRepoForHandler { + return &stubUserRepoForHandler{users: make(map[int64]*service.User)} +} + +func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) { + if u, ok := r.users[id]; ok { + return u, nil + } + return nil, fmt.Errorf("user not found") +} +func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error { + if r.updateErr != nil { + return r.updateErr + } + r.users[user.ID] = user + return nil +} +func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil } +func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) { + return nil, nil +} +func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) { + return nil, nil +} +func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil } +func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil } +func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil } +func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) { + return false, nil +} +func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil } +func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil } +func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error { + return nil +} + +// ==================== NewSoraClientHandler ==================== + +func TestNewSoraClientHandler(t *testing.T) { + h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) + require.NotNil(t, h) +} + +func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) { + h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) + require.NotNil(t, h) + require.Nil(t, h.apiKeyService) +} + +// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ==================== + +var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil) + +type stubAPIKeyRepoForHandler struct { + keys map[int64]*service.APIKey + getErr error +} + +func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler { + return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)} +} + +func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) { + if r.getErr != nil { + return nil, r.getErr + } + if k, ok := r.keys[id]; ok { + return k, nil + } + return nil, fmt.Errorf("api key not found: %d", id) +} +func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil } +func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) { + return "", 0, nil +} +func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil } +func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil } +func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) { + return false, nil +} +func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error { + return nil +} + +// newTestAPIKeyService 创建测试用的 APIKeyService +func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService { + return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{}) +} + +// ==================== Generate: API Key 校验(前端传递 api_key_id)==================== + +func TestGenerate_WithAPIKeyID_Success(t *testing.T) { + // 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + groupID := int64(5) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyActive, + GroupID: &groupID, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.NotZero(t, data["generation_id"]) + + // 验证 api_key_id 已关联到生成记录 + gen := repo.gens[1] + require.NotNil(t, gen.APIKeyID) + require.Equal(t, int64(42), *gen.APIKeyID) +} + +func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) { + // 前端传递不存在的 api_key_id → 400 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不存在") +} + +func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) { + // 前端传递别人的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 999, // 属于 user 999 + Status: service.StatusAPIKeyActive, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不属于") +} + +func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) { + // 前端传递已禁用的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyDisabled, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不可用") +} + +func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) { + // 前端传递配额耗尽的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyQuotaExhausted, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestGenerate_WithAPIKeyID_Expired(t *testing.T) { + // 前端传递已过期的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyExpired, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) { + // apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + h := &SoraClientHandler{genService: genService} // apiKeyService = nil + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // apiKeyService 为 nil → 跳过校验 → api_key_id 不记录 + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) { + // api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyActive, + GroupID: nil, // 无分组 + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) { + // 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) { + // 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + groupID := int64(10) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyActive, + GroupID: &groupID, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // 应使用 body 中的 api_key_id=42,而不是 context 中的 99 + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) { + // 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由) + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + c.Set("api_key_id", int64(99)) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // 应使用 context 中的 api_key_id=99 + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(99), *repo.gens[1].APIKeyID) +} + +func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) { + // JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + // JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验 + // api_key_id=0 不存在 → 400 + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +// ==================== processGeneration: groupID 传递与 ForcePlatform ==================== + +func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) { + // groupID 不为 nil → 不设置 ForcePlatform + // gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + gid := int64(5) + h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") +} + +func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) { + // groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") +} + +func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) { + // 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled + require.Equal(t, "cancelled", repo.gens[1].Status) +} + +// ==================== GenerateRequest JSON 解析 ==================== + +func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) { + // 验证 api_key_id 在 JSON 中正确解析为 *int64 + var req GenerateRequest + err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req) + require.NoError(t, err) + require.NotNil(t, req.APIKeyID) + require.Equal(t, int64(42), *req.APIKeyID) +} + +func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) { + // 不传 api_key_id → 解析后为 nil + var req GenerateRequest + err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req) + require.NoError(t, err) + require.Nil(t, req.APIKeyID) +} + +func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) { + // api_key_id: null → 解析后为 nil + var req GenerateRequest + err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req) + require.NoError(t, err) + require.Nil(t, req.APIKeyID) +} + +func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) { + // 全字段解析 + var req GenerateRequest + err := json.Unmarshal([]byte(`{ + "model":"sora2-landscape-10s", + "prompt":"test prompt", + "media_type":"video", + "video_count":2, + "image_input":"data:image/png;base64,abc", + "api_key_id":100 + }`), &req) + require.NoError(t, err) + require.Equal(t, "sora2-landscape-10s", req.Model) + require.Equal(t, "test prompt", req.Prompt) + require.Equal(t, "video", req.MediaType) + require.Equal(t, 2, req.VideoCount) + require.Equal(t, "data:image/png;base64,abc", req.ImageInput) + require.NotNil(t, req.APIKeyID) + require.Equal(t, int64(100), *req.APIKeyID) +} + +func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) { + // api_key_id 为 nil 时 JSON 序列化应省略 + req := GenerateRequest{Model: "sora2", Prompt: "test"} + b, err := json.Marshal(req) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal(b, &parsed)) + _, hasAPIKeyID := parsed["api_key_id"] + require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略") +} + +func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) { + // api_key_id 不为 nil 时 JSON 序列化应包含 + id := int64(42) + req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id} + b, err := json.Marshal(req) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal(b, &parsed)) + require.Equal(t, float64(42), parsed["api_key_id"]) +} + +// ==================== GetQuota: 有配额服务 ==================== + +func TestGetQuota_WithQuotaService_Success(t *testing.T) { + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, + SoraStorageUsedBytes: 3 * 1024 * 1024, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) + h.GetQuota(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, "user", data["source"]) + require.Equal(t, float64(10*1024*1024), data["quota_bytes"]) + require.Equal(t, float64(3*1024*1024), data["used_bytes"]) +} + +func TestGetQuota_WithQuotaService_Error(t *testing.T) { + // 用户不存在时 GetQuota 返回错误 + userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999) + h.GetQuota(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== Generate: 配额检查 ==================== + +func TestGenerate_QuotaCheckFailed(t *testing.T) { + // 配额超限时返回 429 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 1024, + SoraStorageUsedBytes: 1025, // 已超限 + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +func TestGenerate_QuotaCheckPassed(t *testing.T) { + // 配额充足时允许生成 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, + SoraStorageUsedBytes: 0, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== Stub: SettingRepository (用于 S3 存储测试) ==================== + +var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil) + +type stubSettingRepoForHandler struct { + values map[string]string +} + +func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler { + if values == nil { + values = make(map[string]string) + } + return &stubSettingRepoForHandler{values: values} +} + +func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) { + if v, ok := r.values[key]; ok { + return &service.Setting{Key: key, Value: v}, nil + } + return nil, service.ErrSettingNotFound +} +func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) { + if v, ok := r.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} +func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error { + r.values[key] = value + return nil +} +func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string) + for _, k := range keys { + if v, ok := r.values[k]; ok { + result[k] = v + } + } + return result, nil +} +func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error { + for k, v := range settings { + r.values[k] = v + } + return nil +} +func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) { + return r.values, nil +} +func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error { + delete(r.values, key) + return nil +} + +// ==================== S3 / MediaStorage 辅助函数 ==================== + +// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。 +func newS3StorageForHandler(endpoint string) *service.SoraS3Storage { + settingRepo := newStubSettingRepoForHandler(map[string]string{ + "sora_s3_enabled": "true", + "sora_s3_endpoint": endpoint, + "sora_s3_region": "us-east-1", + "sora_s3_bucket": "test-bucket", + "sora_s3_access_key_id": "AKIATEST", + "sora_s3_secret_access_key": "test-secret", + "sora_s3_prefix": "sora", + "sora_s3_force_path_style": "true", + }) + settingService := service.NewSettingService(settingRepo, &config.Config{}) + return service.NewSoraS3Storage(settingService) +} + +// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。 +func newFakeSourceServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "video/mp4") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("fake video data for test")) + })) +} + +// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。 +// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。 +func newFakeS3Server(mode string) *httptest.Server { + var counter atomic.Int32 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.Copy(io.Discard, r.Body) + _ = r.Body.Close() + + switch mode { + case "ok": + w.Header().Set("ETag", `"test-etag"`) + w.WriteHeader(http.StatusOK) + case "fail": + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`AccessDenied`)) + case "fail-second": + n := counter.Add(1) + if n <= 1 { + w.Header().Set("ETag", `"test-etag"`) + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`AccessDenied`)) + } + } + })) +} + +// ==================== processGeneration 直接调用测试 ==================== + +func TestProcessGeneration_MarkGeneratingFails(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + repo.updateErr = fmt.Errorf("db error") + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + // 直接调用(非 goroutine),MarkGenerating 失败 → 早退 + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating" + // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed + // 因此 ErrorMessage 为空(证明未调用 MarkFailed) + require.Equal(t, "generating", repo.gens[1].Status) + require.Empty(t, repo.gens[1].ErrorMessage) +} + +func TestProcessGeneration_GatewayServiceNil(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + // gatewayService 未设置 → MarkFailed + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") +} + +// ==================== storeMediaWithDegradation: S3 路径 ==================== + +func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeS3, storageType) + require.Len(t, s3Keys, 1) + require.NotEmpty(t, s3Keys[0]) + require.Len(t, storedURLs, 1) + require.Equal(t, storedURL, storedURLs[0]) + require.Contains(t, storedURL, fakeS3.URL) + require.Contains(t, storedURL, "/test-bucket/") + require.Greater(t, fileSize, int64(0)) +} + +func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} + storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, + ) + require.Equal(t, service.SoraStorageTypeS3, storageType) + require.Len(t, s3Keys, 2) + require.Len(t, storedURLs, 2) + require.Equal(t, storedURL, storedURLs[0]) + require.Contains(t, storedURLs[0], fakeS3.URL) + require.Contains(t, storedURLs[1], fakeS3.URL) + require.Greater(t, fileSize, int64(0)) +} + +func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) { + // 上游返回 404 → 下载失败 → S3 上传不会开始 + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer badSource.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + _, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) +} + +func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + // S3 失败,降级到 upstream + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Nil(t, s3Keys) +} + +func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail-second") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} + _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, + ) + // 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Nil(t, s3Keys) +} + +// ==================== storeMediaWithDegradation: 本地存储路径 ==================== + +func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) { + // 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: "/dev/null/invalid_dir", + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + _, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, + ) + // 本地存储失败,降级到 upstream + require.Equal(t, service.SoraStorageTypeUpstream, storageType) +} + +func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + DownloadTimeoutSeconds: 5, + MaxDownloadBytes: 10 * 1024 * 1024, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeLocal, storageType) + require.Nil(t, s3Keys) // 本地存储不返回 S3 keys +} + +func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + DownloadTimeoutSeconds: 5, + MaxDownloadBytes: 10 * 1024 * 1024, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{ + s3Storage: s3Storage, + mediaStorage: mediaStorage, + } + + _, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + // S3 失败 → 本地存储成功 + require.Equal(t, service.SoraStorageTypeLocal, storageType) +} + +// ==================== SaveToStorage: S3 路径 ==================== + +func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, resp["message"], "S3") +} + +func TestSaveToStorage_UpstreamURLExpired(t *testing.T) { + expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer expiredServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: expiredServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusGone, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "过期") +} + +func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Contains(t, data["message"], "S3") + require.NotEmpty(t, data["object_key"]) + // 验证记录已更新为 S3 存储 + require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) +} + +func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v1.mp4", + MediaURLs: []string{ + sourceServer.URL + "/v1.mp4", + sourceServer.URL + "/v2.mp4", + }, + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Len(t, data["object_keys"].([]any), 2) + require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) + require.Len(t, repo.gens[1].S3ObjectKeys, 2) + require.Len(t, repo.gens[1].MediaURLs, 2) +} + +func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 100 * 1024 * 1024, + SoraStorageUsedBytes: 0, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusOK, rec.Code) + // 验证配额已累加 + require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0)) +} + +func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败 + repo.updateErr = fmt.Errorf("db error") + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== GetStorageStatus: S3 路径 ==================== + +func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) { + // S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket) + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, true, data["s3_enabled"]) + require.Equal(t, false, data["s3_healthy"]) +} + +func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) { + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, true, data["s3_enabled"]) + require.Equal(t, true, data["s3_healthy"]) +} + +// ==================== Stub: AccountRepository (用于 GatewayService) ==================== + +var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil) + +type stubAccountRepoForHandler struct { + accounts []service.Account +} + +func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil } +func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) { + for i := range r.accounts { + if r.accounts[i].ID == id { + return &r.accounts[i], nil + } + } + return nil, fmt.Errorf("account not found") +} +func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) { + return false, nil +} +func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil } +func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil } +func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error { + return nil +} +func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) { + return 0, nil +} +func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil } +func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error { + return nil +} +func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error { + return nil +} +func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error { + return nil +} +func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error { + return nil +} +func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) { + return 0, nil +} + +// ==================== Stub: SoraClient (用于 SoraGatewayService) ==================== + +var _ service.SoraClient = (*stubSoraClientForHandler)(nil) + +type stubSoraClientForHandler struct { + videoStatus *service.SoraVideoTaskStatus +} + +func (s *stubSoraClientForHandler) Enabled() bool { return true } +func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) { + return "task-image", nil +} +func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) { + return "task-video", nil +} +func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) { + return "task-video", nil +} +func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) { + return nil, nil +} +func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) { + return nil, nil +} +func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error { + return nil +} +func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error { + return nil +} +func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error { + return nil +} +func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) { + return nil, nil +} +func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) { + return s.videoStatus, nil +} + +// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ==================== + +// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。 +func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { + return service.NewGatewayService( + accountRepo, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + ) +} + +// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。 +func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + return service.NewSoraGatewayService(soraClient, nil, nil, cfg) +} + +// ==================== processGeneration: 更多路径测试 ==================== + +func TestProcessGeneration_SelectAccountError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + // accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts" + accountRepo := &stubAccountRepoForHandler{accounts: nil} + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") +} + +func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + // 提供可用账号使 SelectAccountForModel 成功 + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + // soraGatewayService 为 nil + h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService") +} + +func TestProcessGeneration_ForwardError(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + // SoraClient 返回视频任务失败 + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "failed", + ErrorMsg: "content policy violation", + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "生成失败") +} + +func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + // MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration + // 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。 + repo.getByIDOverrideAfterN = 1 + repo.getByIDOverrideStatus = "cancelled" + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"}, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating) + require.Equal(t, "generating", repo.gens[1].Status) +} + +func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + // SoraClient 返回 completed 但无 URL + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: nil, // 无 URL + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL") +} + +func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + // MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次) + // 第 2 次返回 "cancelled" 状态,模拟外部取消 + repo.getByIDOverrideAfterN = 1 + repo.getByIDOverrideStatus = "cancelled" + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/video.mp4"}, + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating) + require.Equal(t, "generating", repo.gens[1].Status) +} + +func TestProcessGeneration_FullSuccessUpstream(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/video.mp4"}, + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + // 无 S3 和本地存储,降级到 upstream + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) + require.Equal(t, "completed", repo.gens[1].Status) + require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType) + require.NotEmpty(t, repo.gens[1].MediaURL) +} + +func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{sourceServer.URL + "/video.mp4"}, + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + s3Storage := newS3StorageForHandler(fakeS3.URL) + + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + s3Storage: s3Storage, + quotaService: quotaService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) + require.Equal(t, "completed", repo.gens[1].Status) + require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) + require.NotEmpty(t, repo.gens[1].S3ObjectKeys) + require.Greater(t, repo.gens[1].FileSizeBytes, int64(0)) + // 验证配额已累加 + require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0)) +} + +func TestProcessGeneration_MarkCompletedFails(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败 + repo.updateCallCount = new(int32) + repo.updateFailAfterN = 1 + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/video.mp4"}, + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) + // MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。 + // 由于 stub 存储的是指针,内存中的状态已被修改为 completed。 + // 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。 + require.Equal(t, "completed", repo.gens[1].Status) +} + +// ==================== cleanupStoredMedia 直接测试 ==================== + +func TestCleanupStoredMedia_S3Path(t *testing.T) { + // S3 清理路径:s3Storage 为 nil 时不 panic + h := &SoraClientHandler{} + // 不应 panic + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) +} + +func TestCleanupStoredMedia_LocalPath(t *testing.T) { + // 本地清理路径:mediaStorage 为 nil 时不 panic + h := &SoraClientHandler{} + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"}) +} + +func TestCleanupStoredMedia_UpstreamPath(t *testing.T) { + // upstream 类型不清理 + h := &SoraClientHandler{} + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil) +} + +func TestCleanupStoredMedia_EmptyKeys(t *testing.T) { + // 空 keys 不触发清理 + h := &SoraClientHandler{} + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil) + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil) +} + +// ==================== DeleteGeneration: 本地存储清理路径 ==================== + +func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-delete-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, + UserID: 1, + Status: "completed", + StorageType: service.SoraStorageTypeLocal, + MediaURL: "video/test.mp4", + MediaURLs: []string{"video/test.mp4"}, + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + _, exists := repo.gens[1] + require.False(t, exists) +} + +func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) { + // MediaURLs 为空,使用 MediaURL 作为清理路径 + tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, + UserID: 1, + Status: "completed", + StorageType: service.SoraStorageTypeLocal, + MediaURL: "video/test.mp4", + MediaURLs: nil, // 空 + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) { + // 非本地存储类型 → 跳过清理 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, + UserID: 1, + Status: "completed", + StorageType: service.SoraStorageTypeUpstream, + MediaURL: "https://upstream.com/v.mp4", + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestDeleteGeneration_DeleteError(t *testing.T) { + // repo.Delete 出错 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"} + repo.deleteErr = fmt.Errorf("delete failed") + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +// ==================== fetchUpstreamModels 测试 ==================== + +func TestFetchUpstreamModels_NilGateway(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + h := &SoraClientHandler{} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "gatewayService 未初始化") +} + +func TestFetchUpstreamModels_NoAccounts(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + accountRepo := &stubAccountRepoForHandler{accounts: nil} + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "选择 Sora 账号失败") +} + +func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "不支持模型同步") +} + +func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"base_url": "https://sora.test"}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "api_key") +} + +func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com" + // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败 + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test"}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) +} + +func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "状态码 500") +} + +func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("not json")) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "解析响应失败") +} + +func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "空模型列表") +} + +func TestFetchUpstreamModels_Success(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求头 + require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization")) + require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models")) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + families, err := h.fetchUpstreamModels(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, families) +} + +func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "未能从上游模型列表中识别") +} + +// ==================== getModelFamilies 缓存测试 ==================== + +func TestGetModelFamilies_CachesLocalConfig(t *testing.T) { + // gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置 + h := &SoraClientHandler{} + families := h.getModelFamilies(context.Background()) + require.NotEmpty(t, families) + + // 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL) + families2 := h.getModelFamilies(context.Background()) + require.Equal(t, families, families2) + require.False(t, h.modelCacheUpstream) +} + +func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) { + t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + + families := h.getModelFamilies(context.Background()) + require.NotEmpty(t, families) + require.True(t, h.modelCacheUpstream) + + // 第二次调用命中缓存 + families2 := h.getModelFamilies(context.Background()) + require.Equal(t, families, families2) +} + +func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) { + // 预设过期的缓存(modelCacheUpstream=false → 短 TTL) + h := &SoraClientHandler{ + cachedFamilies: []service.SoraModelFamily{{ID: "old"}}, + modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期 + modelCacheUpstream: false, + } + // gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存 + families := h.getModelFamilies(context.Background()) + require.NotEmpty(t, families) + // 缓存已刷新,不再是 "old" + found := false + for _, f := range families { + if f.ID == "old" { + found = true + } + } + require.False(t, found, "过期缓存应被刷新") +} + +// ==================== processGeneration: groupID 与 ForcePlatform ==================== + +func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) { + // groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 空账号列表 → SelectAccountForModel 失败 + accountRepo := &stubAccountRepoForHandler{accounts: nil} + gatewayService := newMinimalGatewayService(accountRepo) + + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") +} + +// ==================== Generate: 配额检查非 QuotaExceeded 错误 ==================== + +func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) { + // quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error + userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil) + + body := `{"model":"sora2-landscape-10s","prompt":"test"}` + c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +// ==================== Generate: CreatePending 并发限制错误 ==================== + +// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口 +type stubSoraGenRepoWithAtomicCreate struct { + stubSoraGenRepo + limitErr error +} + +func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error { + if r.limitErr != nil { + return r.limitErr + } + return r.stubSoraGenRepo.Create(context.Background(), gen) +} + +func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) { + repo := &stubSoraGenRepoWithAtomicCreate{ + stubSoraGenRepo: *newStubSoraGenRepo(), + limitErr: service.ErrSoraGenerationConcurrencyLimit, + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil) + + body := `{"model":"sora2-landscape-10s","prompt":"test"}` + c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, resp["message"], "3") +} + +// ==================== SaveToStorage: 配额超限 ==================== + +func TestSaveToStorage_QuotaExceeded(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户配额已满 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10, + SoraStorageUsedBytes: 10, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ==================== + +func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error + userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== SaveToStorage: MediaURLs 全为空 ==================== + +func TestSaveToStorage_EmptyMediaURLs(t *testing.T) { + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: "", + MediaURLs: []string{}, + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, resp["message"], "已过期") +} + +// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ==================== + +func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail-second") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v1.mp4", + MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"}, + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ==================== + +func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + repo.updateErr = fmt.Errorf("db error") + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 100 * 1024 * 1024, + SoraStorageUsedBytes: 0, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== cleanupStoredMedia: 实际 S3 删除路径 ==================== + +func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) { + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil) +} + +func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) { + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) +} + +func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"}) +} + +// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ==================== + +func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-del-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: service.SoraStorageTypeLocal, + MediaURL: "nonexistent/video.mp4", + MediaURLs: []string{"nonexistent/video.mp4"}, + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== CancelGeneration: 任务已结束冲突 ==================== + +func TestCancelGeneration_AlreadyCompleted(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go new file mode 100644 index 00000000..48c1e451 --- /dev/null +++ b/backend/internal/handler/sora_gateway_handler.go @@ -0,0 +1,685 @@ +package handler + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +// SoraGatewayHandler handles Sora chat completions requests +type SoraGatewayHandler struct { + gatewayService *service.GatewayService + soraGatewayService *service.SoraGatewayService + billingCacheService *service.BillingCacheService + usageRecordWorkerPool *service.UsageRecordWorkerPool + concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int + streamMode string + soraTLSEnabled bool + soraMediaSigningKey string + soraMediaRoot string +} + +// NewSoraGatewayHandler creates a new SoraGatewayHandler +func NewSoraGatewayHandler( + gatewayService *service.GatewayService, + soraGatewayService *service.SoraGatewayService, + concurrencyService *service.ConcurrencyService, + billingCacheService *service.BillingCacheService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, + cfg *config.Config, +) *SoraGatewayHandler { + pingInterval := time.Duration(0) + maxAccountSwitches := 3 + streamMode := "force" + soraTLSEnabled := true + signKey := "" + mediaRoot := "/app/data/sora" + if cfg != nil { + pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + if cfg.Gateway.MaxAccountSwitches > 0 { + maxAccountSwitches = cfg.Gateway.MaxAccountSwitches + } + if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" { + streamMode = mode + } + soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint + signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) + if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" { + mediaRoot = root + } + } + return &SoraGatewayHandler{ + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + billingCacheService: billingCacheService, + usageRecordWorkerPool: usageRecordWorkerPool, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + maxAccountSwitches: maxAccountSwitches, + streamMode: strings.ToLower(streamMode), + soraTLSEnabled: soraTLSEnabled, + soraMediaSigningKey: signKey, + soraMediaRoot: mediaRoot, + } +} + +// ChatCompletions handles Sora /v1/chat/completions endpoint +func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.sora_gateway.chat_completions", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + setOpsRequestContext(c, "", false, body) + + // 校验请求体 JSON 合法性 + if !gjson.ValidBytes(body) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + + msgsResult := gjson.GetBytes(body, "messages") + if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required") + return + } + + clientStream := gjson.GetBytes(body, "stream").Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream)) + if !clientStream { + if h.streamMode == "error" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true") + return + } + var err error + body, err = sjson.SetBytes(body, "stream", true) + if err != nil { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return + } + } + + setOpsRequestContext(c, reqModel, clientStream, body) + + platform := "" + if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { + platform = forced + } else if apiKey.Group != nil { + platform = apiKey.Group.Platform + } + if platform != service.PlatformSora { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform") + return + } + + streamStarted := false + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err != nil { + reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err)) + } else if !canWait { + reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait)) + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted) + if err != nil { + reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := generateOpenAISessionHash(c, body) + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 + var lastFailoverBody []byte + var lastFailoverHeaders http.Header + + for { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") + if err != nil { + reqLog.Warn("sora.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + return + } + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ + zap.Int("last_upstream_status", lastFailoverStatus), + } + if rayID != "" { + fields = append(fields, zap.String("last_upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("last_upstream_content_type", contentType)) + } + reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...) + h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) + return + } + account := selection.Account + setOpsSelectedAccount(c, account.ID, account.Platform) + proxyBound := account.ProxyID != nil + proxyID := int64(0) + if account.ProxyID != nil { + proxyID = *account.ProxyID + } + tlsFingerprintEnabled := h.soraTLSEnabled + + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + reqLog.Warn("sora.account_wait_counter_increment_failed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) + } else if !canWait { + reqLog.Info("sora.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + }() + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + clientStream, + &streamStarted, + ) + if err != nil { + reqLog.Warn("sora.account_slot_acquire_failed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + if switchCount >= maxAccountSwitches { + lastFailoverStatus = failoverErr.StatusCode + lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) + lastFailoverBody = failoverErr.ResponseBody + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + } + if rayID != "" { + fields = append(fields, zap.String("upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("upstream_content_type", contentType)) + } + reqLog.Warn("sora.upstream_failover_exhausted", fields...) + h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) + return + } + lastFailoverStatus = failoverErr.StatusCode + lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) + lastFailoverBody = failoverErr.ResponseBody + switchCount++ + upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody) + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.String("upstream_error_code", upstreamErrCode), + zap.String("upstream_error_message", upstreamErrMsg), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + } + if rayID != "" { + fields = append(fields, zap.String("upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("upstream_content_type", contentType)) + } + reqLog.Warn("sora.upstream_failover_switching", fields...) + continue + } + reqLog.Error("sora.forward_failed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) + return + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + }); err != nil { + logger.L().With( + zap.String("component", "handler.sora_gateway.chat_completions"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("sora.record_usage_failed", zap.Error(err)) + } + }) + reqLog.Debug("sora.request_completed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("switch_count", switchCount), + ) + return + } +} + +func generateOpenAISessionHash(c *gin.Context, body []byte) string { + if c == nil { + return "" + } + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + } + if sessionID == "" { + return "" + } + hash := sha256.Sum256([]byte(sessionID)) + return hex.EncodeToString(hash[:]) +} + +func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { + if task == nil { + return + } + if h.usageRecordWorkerPool != nil { + h.usageRecordWorkerPool.Submit(task) + return + } + // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.sora_gateway.chat_completions"), + zap.Any("panic", recovered), + ).Error("sora.usage_record_task_panic_recovered") + } + }() + task(ctx) +} + +func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", + fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) +} + +func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) { + if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) { + baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode) + return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) + } + + upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody) + if strings.EqualFold(upstreamCode, "cf_shield_429") { + baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry." + return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) + } + if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) { + switch statusCode { + case 401, 403, 404, 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", upstreamMessage + case 429: + return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage + } + } + + switch statusCode { + case 401: + return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" + case 403: + return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" + case 404: + if strings.EqualFold(upstreamCode, "unsupported_country_code") { + return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator" + } + return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator" + case 429: + return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" + case 529: + return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" + default: + return http.StatusBadGateway, "upstream_error", "Upstream request failed" + } +} + +func cloneHTTPHeaders(headers http.Header) http.Header { + if headers == nil { + return nil + } + return headers.Clone() +} + +func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) { + if headers != nil { + mitigated = strings.TrimSpace(headers.Get("cf-mitigated")) + contentType = strings.TrimSpace(headers.Get("content-type")) + if contentType == "" { + contentType = strings.TrimSpace(headers.Get("Content-Type")) + } + } + rayID = soraerror.ExtractCloudflareRayID(headers, body) + return rayID, mitigated, contentType +} + +func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) +} + +func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool { + message = strings.TrimSpace(message) + if message == "" { + return false + } + if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests { + lower := strings.ToLower(message) + if strings.Contains(lower, "Just a moment...`) + + h := &SoraGatewayHandler{} + h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true) + + lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "upstream_error", errorObj["type"]) + msg, _ := errorObj["message"].(string) + require.Contains(t, msg, "Cloudflare challenge") + require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA") +} + +func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + headers := http.Header{} + headers.Set("cf-ray", "9d03b68c086027a1-SEA") + body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`) + + h := &SoraGatewayHandler{} + h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true) + + lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "rate_limit_error", errorObj["type"]) + msg, _ := errorObj["message"].(string) + require.Contains(t, msg, "Cloudflare shield") + require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA") +} + +func TestExtractSoraFailoverHeaderInsights(t *testing.T) { + headers := http.Header{} + headers.Set("cf-mitigated", "challenge") + headers.Set("content-type", "text/html") + body := []byte(``) + + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body) + require.Equal(t, "9cff2d62d83bb98d", rayID) + require.Equal(t, "challenge", mitigated) + require.Equal(t, "text/html", contentType) +} diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index 129dbfa6..2bd0e0d7 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -2,6 +2,7 @@ package handler import ( "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -65,8 +66,17 @@ func (h *UsageHandler) List(c *gin.Context) { // Parse additional filters model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -114,6 +124,7 @@ func (h *UsageHandler) List(c *gin.Context) { UserID: subject.UserID, // Always filter by current user for security APIKeyID: apiKeyID, Model: model, + RequestType: requestType, Stream: stream, BillingType: billingType, StartTime: startTime, @@ -392,7 +403,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) { return } - stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs) + stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{}) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/usage_handler_request_type_test.go b/backend/internal/handler/usage_handler_request_type_test.go new file mode 100644 index 00000000..7c4c7913 --- /dev/null +++ b/backend/internal/handler/usage_handler_request_type_test.go @@ -0,0 +1,80 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userUsageRepoCapture struct { + service.UsageLogRepository + listFilters usagestats.UsageLogFilters +} + +func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + s.listFilters = filters + return []service.UsageLog{}, &pagination.PaginationResult{ + Total: 0, + Page: params.Page, + PageSize: params.PageSize, + Pages: 0, + }, nil +} + +func newUserUsageRequestTypeTestRouter(repo *userUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + usageSvc := service.NewUsageService(repo, nil, nil, nil) + handler := NewUsageHandler(usageSvc, nil) + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42}) + c.Next() + }) + router.GET("/usage", handler.List) + return router +} + +func TestUserUsageListRequestTypePriority(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?request_type=ws_v2&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, int64(42), repo.listFilters.UserID) + require.NotNil(t, repo.listFilters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType) + require.Nil(t, repo.listFilters.Stream) +} + +func TestUserUsageListInvalidRequestType(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?request_type=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestUserUsageListInvalidStream(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?stream=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go new file mode 100644 index 00000000..c7c48e14 --- /dev/null +++ b/backend/internal/handler/usage_record_submit_task_test.go @@ -0,0 +1,184 @@ +package handler + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func newUsageRecordTestPool(t *testing.T) *service.UsageRecordWorkerPool { + t.Helper() + pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{ + WorkerCount: 1, + QueueSize: 8, + TaskTimeout: time.Second, + OverflowPolicy: "drop", + OverflowSamplePercent: 0, + AutoScaleEnabled: false, + }) + t.Cleanup(pool.Stop) + return pool +} + +func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { + pool := newUsageRecordTestPool(t) + h := &GatewayHandler{usageRecordWorkerPool: pool} + + done := make(chan struct{}) + h.submitUsageRecordTask(func(ctx context.Context) { + close(done) + }) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("task not executed") + } +} + +func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { + h := &GatewayHandler{} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load()) +} + +func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { + h := &GatewayHandler{} + require.NotPanics(t, func() { + h.submitUsageRecordTask(nil) + }) +} + +func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &GatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { + pool := newUsageRecordTestPool(t) + h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool} + + done := make(chan struct{}) + h.submitUsageRecordTask(func(ctx context.Context) { + close(done) + }) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("task not executed") + } +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { + h := &OpenAIGatewayHandler{} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load()) +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { + h := &OpenAIGatewayHandler{} + require.NotPanics(t, func() { + h.submitUsageRecordTask(nil) + }) +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &OpenAIGatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} + +func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { + pool := newUsageRecordTestPool(t) + h := &SoraGatewayHandler{usageRecordWorkerPool: pool} + + done := make(chan struct{}) + h.submitUsageRecordTask(func(ctx context.Context) { + close(done) + }) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("task not executed") + } +} + +func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { + h := &SoraGatewayHandler{} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load()) +} + +func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { + h := &SoraGatewayHandler{} + require.NotPanics(t, func() { + h.submitUsageRecordTask(nil) + }) +} + +func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &SoraGatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} diff --git a/backend/internal/handler/user_msg_queue_helper.go b/backend/internal/handler/user_msg_queue_helper.go new file mode 100644 index 00000000..50449b13 --- /dev/null +++ b/backend/internal/handler/user_msg_queue_helper.go @@ -0,0 +1,237 @@ +package handler + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// UserMsgQueueHelper 用户消息串行队列 Handler 层辅助 +// 复用 ConcurrencyHelper 的退避 + SSE ping 模式 +type UserMsgQueueHelper struct { + queueService *service.UserMessageQueueService + pingFormat SSEPingFormat + pingInterval time.Duration +} + +// NewUserMsgQueueHelper 创建用户消息串行队列辅助 +func NewUserMsgQueueHelper( + queueService *service.UserMessageQueueService, + pingFormat SSEPingFormat, + pingInterval time.Duration, +) *UserMsgQueueHelper { + if pingInterval <= 0 { + pingInterval = defaultPingInterval + } + return &UserMsgQueueHelper{ + queueService: queueService, + pingFormat: pingFormat, + pingInterval: pingInterval, + } +} + +// AcquireWithWait 等待获取串行锁,流式请求期间发送 SSE ping +// 返回的 releaseFunc 内部使用 sync.Once,确保只执行一次释放 +func (h *UserMsgQueueHelper) AcquireWithWait( + c *gin.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + timeout time.Duration, + reqLog *zap.Logger, +) (releaseFunc func(), err error) { + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + // 先尝试立即获取 + result, err := h.queueService.TryAcquire(ctx, accountID) + if err != nil { + return nil, err // fail-open 已在 service 层处理 + } + + if result.Acquired { + // 获取成功,执行 RPM 自适应延迟 + if err := h.queueService.EnforceDelay(ctx, accountID, baseRPM); err != nil { + if ctx.Err() != nil { + // 延迟期间 context 取消,释放锁 + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = h.queueService.Release(bgCtx, accountID, result.RequestID) + bgCancel() + return nil, ctx.Err() + } + } + reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID)) + return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil + } + + // 需要等待:指数退避轮询 + return h.waitForLockWithPing(c, ctx, accountID, baseRPM, isStream, streamStarted, reqLog) +} + +// waitForLockWithPing 等待获取锁,流式请求期间发送 SSE ping +func (h *UserMsgQueueHelper) waitForLockWithPing( + c *gin.Context, + ctx context.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), error) { + needPing := isStream && h.pingFormat != "" + + var flusher http.Flusher + if needPing { + var ok bool + flusher, ok = c.Writer.(http.Flusher) + if !ok { + needPing = false + } + } + + var pingCh <-chan time.Time + if needPing { + pingTicker := time.NewTicker(h.pingInterval) + defer pingTicker.Stop() + pingCh = pingTicker.C + } + + backoff := initialBackoff + timer := time.NewTimer(backoff) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("umq wait timeout for account %d", accountID) + + case <-pingCh: + if !*streamStarted { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + *streamStarted = true + } + if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil { + return nil, err + } + flusher.Flush() + + case <-timer.C: + result, err := h.queueService.TryAcquire(ctx, accountID) + if err != nil { + return nil, err + } + if result.Acquired { + // 获取成功,执行 RPM 自适应延迟 + if delayErr := h.queueService.EnforceDelay(ctx, accountID, baseRPM); delayErr != nil { + if ctx.Err() != nil { + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = h.queueService.Release(bgCtx, accountID, result.RequestID) + bgCancel() + return nil, ctx.Err() + } + } + reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID)) + return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil + } + backoff = nextBackoff(backoff) + timer.Reset(backoff) + } + } +} + +// makeReleaseFunc 创建锁释放函数(使用 sync.Once 确保只执行一次) +func (h *UserMsgQueueHelper) makeReleaseFunc(accountID int64, requestID string, reqLog *zap.Logger) func() { + var once sync.Once + return func() { + once.Do(func() { + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer bgCancel() + if err := h.queueService.Release(bgCtx, accountID, requestID); err != nil { + reqLog.Warn("gateway.umq_release_failed", + zap.Int64("account_id", accountID), + zap.Error(err), + ) + } else { + reqLog.Debug("gateway.umq_lock_released", zap.Int64("account_id", accountID)) + } + }) + } +} + +// ThrottleWithPing 软性限速模式:施加 RPM 自适应延迟,流式期间发送 SSE ping +// 不获取串行锁,不阻塞并发。返回后即可转发请求。 +func (h *UserMsgQueueHelper) ThrottleWithPing( + c *gin.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + timeout time.Duration, + reqLog *zap.Logger, +) error { + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + delay := h.queueService.CalculateRPMAwareDelay(ctx, accountID, baseRPM) + if delay <= 0 { + return nil + } + + reqLog.Debug("gateway.umq_throttle_delay", + zap.Int64("account_id", accountID), + zap.Duration("delay", delay), + ) + + // 延迟期间发送 SSE ping(复用 waitForLockWithPing 的 ping 逻辑) + needPing := isStream && h.pingFormat != "" + var flusher http.Flusher + if needPing { + flusher, _ = c.Writer.(http.Flusher) + if flusher == nil { + needPing = false + } + } + + var pingCh <-chan time.Time + if needPing { + pingTicker := time.NewTicker(h.pingInterval) + defer pingTicker.Stop() + pingCh = pingTicker.C + } + + timer := time.NewTimer(delay) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-pingCh: + // SSE ping 逻辑(与 waitForLockWithPing 一致) + if !*streamStarted { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + *streamStarted = true + } + if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil { + return err + } + flusher.Flush() + case <-timer.C: + return nil + } + } +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 7b62149c..76f5a979 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -14,6 +14,7 @@ func ProvideAdminHandlers( groupHandler *admin.GroupHandler, accountHandler *admin.AccountHandler, announcementHandler *admin.AnnouncementHandler, + dataManagementHandler *admin.DataManagementHandler, oauthHandler *admin.OAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler, geminiOAuthHandler *admin.GeminiOAuthHandler, @@ -28,6 +29,7 @@ func ProvideAdminHandlers( usageHandler *admin.UsageHandler, userAttributeHandler *admin.UserAttributeHandler, errorPassthroughHandler *admin.ErrorPassthroughHandler, + apiKeyHandler *admin.AdminAPIKeyHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -35,6 +37,7 @@ func ProvideAdminHandlers( Group: groupHandler, Account: accountHandler, Announcement: announcementHandler, + DataManagement: dataManagementHandler, OAuth: oauthHandler, OpenAIOAuth: openaiOAuthHandler, GeminiOAuth: geminiOAuthHandler, @@ -49,12 +52,13 @@ func ProvideAdminHandlers( Usage: usageHandler, UserAttribute: userAttributeHandler, ErrorPassthrough: errorPassthroughHandler, + APIKey: apiKeyHandler, } } // ProvideSystemHandler creates admin.SystemHandler with UpdateService -func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler { - return admin.NewSystemHandler(updateService) +func ProvideSystemHandler(updateService *service.UpdateService, lockService *service.SystemOperationLockService) *admin.SystemHandler { + return admin.NewSystemHandler(updateService, lockService) } // ProvideSettingHandler creates SettingHandler with version from BuildInfo @@ -74,8 +78,12 @@ func ProvideHandlers( adminHandlers *AdminHandlers, gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, + soraGatewayHandler *SoraGatewayHandler, + soraClientHandler *SoraClientHandler, settingHandler *SettingHandler, totpHandler *TotpHandler, + _ *service.IdempotencyCoordinator, + _ *service.IdempotencyCleanupService, ) *Handlers { return &Handlers{ Auth: authHandler, @@ -88,6 +96,8 @@ func ProvideHandlers( Admin: adminHandlers, Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, + SoraGateway: soraGatewayHandler, + SoraClient: soraClientHandler, Setting: settingHandler, Totp: totpHandler, } @@ -105,6 +115,7 @@ var ProviderSet = wire.NewSet( NewAnnouncementHandler, NewGatewayHandler, NewOpenAIGatewayHandler, + NewSoraGatewayHandler, NewTotpHandler, ProvideSettingHandler, @@ -114,6 +125,7 @@ var ProviderSet = wire.NewSet( admin.NewGroupHandler, admin.NewAccountHandler, admin.NewAnnouncementHandler, + admin.NewDataManagementHandler, admin.NewOAuthHandler, admin.NewOpenAIOAuthHandler, admin.NewGeminiOAuthHandler, @@ -128,6 +140,7 @@ var ProviderSet = wire.NewSet( admin.NewUsageHandler, admin.NewUserAttributeHandler, admin.NewErrorPassthroughHandler, + admin.NewAdminAPIKeyHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/integration/e2e_gateway_test.go b/backend/internal/integration/e2e_gateway_test.go index ec0b29f7..8ee3f22e 100644 --- a/backend/internal/integration/e2e_gateway_test.go +++ b/backend/internal/integration/e2e_gateway_test.go @@ -21,11 +21,18 @@ var ( // - "" (默认): 使用 /v1/messages, /v1beta/models(混合模式,可调度 antigravity 账户) // - "/antigravity": 使用 /antigravity/v1/messages, /antigravity/v1beta/models(非混合模式,仅 antigravity 账户) endpointPrefix = getEnv("ENDPOINT_PREFIX", "") - claudeAPIKey = "sk-8e572bc3b3de92ace4f41f4256c28600ca11805732a7b693b5c44741346bbbb3" - geminiAPIKey = "sk-5950197a2085b38bbe5a1b229cc02b8ece914963fc44cacc06d497ae8b87410f" testInterval = 1 * time.Second // 测试间隔,防止限流 ) +const ( + // 注意:E2E 测试请使用环境变量注入密钥,避免任何凭证进入仓库历史。 + // 例如: + // export CLAUDE_API_KEY="sk-..." + // export GEMINI_API_KEY="sk-..." + claudeAPIKeyEnv = "CLAUDE_API_KEY" + geminiAPIKeyEnv = "GEMINI_API_KEY" +) + func getEnv(key, defaultVal string) string { if v := os.Getenv(key); v != "" { return v @@ -65,16 +72,45 @@ func TestMain(m *testing.M) { if endpointPrefix != "" { mode = "Antigravity 模式" } - fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s)\n\n", baseURL, endpointPrefix, mode) + claudeKeySet := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) != "" + geminiKeySet := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) != "" + fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s, %s=%v, %s=%v)\n\n", + baseURL, + endpointPrefix, + mode, + claudeAPIKeyEnv, + claudeKeySet, + geminiAPIKeyEnv, + geminiKeySet, + ) os.Exit(m.Run()) } +func requireClaudeAPIKey(t *testing.T) string { + t.Helper() + key := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) + if key == "" { + t.Skipf("未设置 %s,跳过 Claude 相关 E2E 测试", claudeAPIKeyEnv) + } + return key +} + +func requireGeminiAPIKey(t *testing.T) string { + t.Helper() + key := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) + if key == "" { + t.Skipf("未设置 %s,跳过 Gemini 相关 E2E 测试", geminiAPIKeyEnv) + } + return key +} + // TestClaudeModelsList 测试 GET /v1/models func TestClaudeModelsList(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) url := baseURL + endpointPrefix + "/v1/models" req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) @@ -106,10 +142,11 @@ func TestClaudeModelsList(t *testing.T) { // TestGeminiModelsList 测试 GET /v1beta/models func TestGeminiModelsList(t *testing.T) { + geminiKey := requireGeminiAPIKey(t) url := baseURL + endpointPrefix + "/v1beta/models" req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Authorization", "Bearer "+geminiAPIKey) + req.Header.Set("Authorization", "Bearer "+geminiKey) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) @@ -137,21 +174,22 @@ func TestGeminiModelsList(t *testing.T) { // TestClaudeMessages 测试 Claude /v1/messages 接口 func TestClaudeMessages(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) for i, model := range claudeModels { if i > 0 { time.Sleep(testInterval) } t.Run(model+"_非流式", func(t *testing.T) { - testClaudeMessage(t, model, false) + testClaudeMessage(t, claudeKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_流式", func(t *testing.T) { - testClaudeMessage(t, model, true) + testClaudeMessage(t, claudeKey, model, true) }) } } -func testClaudeMessage(t *testing.T, model string, stream bool) { +func testClaudeMessage(t *testing.T, claudeKey string, model string, stream bool) { url := baseURL + endpointPrefix + "/v1/messages" payload := map[string]any{ @@ -166,7 +204,7 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -213,21 +251,22 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { // TestGeminiGenerateContent 测试 Gemini /v1beta/models/:model 接口 func TestGeminiGenerateContent(t *testing.T) { + geminiKey := requireGeminiAPIKey(t) for i, model := range geminiModels { if i > 0 { time.Sleep(testInterval) } t.Run(model+"_非流式", func(t *testing.T) { - testGeminiGenerate(t, model, false) + testGeminiGenerate(t, geminiKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_流式", func(t *testing.T) { - testGeminiGenerate(t, model, true) + testGeminiGenerate(t, geminiKey, model, true) }) } } -func testGeminiGenerate(t *testing.T, model string, stream bool) { +func testGeminiGenerate(t *testing.T, geminiKey string, model string, stream bool) { action := "generateContent" if stream { action = "streamGenerateContent" @@ -254,7 +293,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+geminiAPIKey) + req.Header.Set("Authorization", "Bearer "+geminiKey) client := &http.Client{Timeout: 60 * time.Second} resp, err := client.Do(req) @@ -301,6 +340,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { // TestClaudeMessagesWithComplexTools 测试带复杂工具 schema 的请求 // 模拟 Claude Code 发送的请求,包含需要清理的 JSON Schema 字段 func TestClaudeMessagesWithComplexTools(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) // 测试模型列表(只测试几个代表性模型) models := []string{ "claude-opus-4-5-20251101", // Claude 模型 @@ -312,12 +352,12 @@ func TestClaudeMessagesWithComplexTools(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_复杂工具", func(t *testing.T) { - testClaudeMessageWithTools(t, model) + testClaudeMessageWithTools(t, claudeKey, model) }) } } -func testClaudeMessageWithTools(t *testing.T, model string) { +func testClaudeMessageWithTools(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 构造包含复杂 schema 的工具定义(模拟 Claude Code 的工具) @@ -473,7 +513,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -519,6 +559,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) { // 验证:当历史 assistant 消息包含 tool_use 但没有 signature 时, // 系统应自动添加 dummy thought_signature 避免 Gemini 400 错误 func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) models := []string{ "claude-haiku-4-5-20251001", // gemini-3-flash } @@ -527,12 +568,12 @@ func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_thinking模式工具调用", func(t *testing.T) { - testClaudeThinkingWithToolHistory(t, model) + testClaudeThinkingWithToolHistory(t, claudeKey, model) }) } } -func testClaudeThinkingWithToolHistory(t *testing.T, model string) { +func testClaudeThinkingWithToolHistory(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 模拟历史对话:用户请求 → assistant 调用工具 → 工具返回 → 继续对话 @@ -600,7 +641,7 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -649,6 +690,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { if endpointPrefix != "/antigravity" { t.Skip("仅在 Antigravity 模式下运行") } + claudeKey := requireClaudeAPIKey(t) // 测试通过 Claude 端点调用 Gemini 模型 geminiViaClaude := []string{ @@ -664,11 +706,11 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_通过Claude端点", func(t *testing.T) { - testClaudeMessage(t, model, false) + testClaudeMessage(t, claudeKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_通过Claude端点_流式", func(t *testing.T) { - testClaudeMessage(t, model, true) + testClaudeMessage(t, claudeKey, model, true) }) } } @@ -676,6 +718,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { // TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景 // 验证:Gemini 模型接受没有 signature 的 thinking block func TestClaudeMessagesWithNoSignature(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) models := []string{ "claude-haiku-4-5-20251001", // gemini-3-flash - 支持无 signature } @@ -684,12 +727,12 @@ func TestClaudeMessagesWithNoSignature(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_无signature", func(t *testing.T) { - testClaudeWithNoSignature(t, model) + testClaudeWithNoSignature(t, claudeKey, model) }) } } -func testClaudeWithNoSignature(t *testing.T, model string) { +func testClaudeWithNoSignature(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 模拟历史对话包含 thinking block 但没有 signature @@ -732,7 +775,7 @@ func testClaudeWithNoSignature(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -777,6 +820,7 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { if endpointPrefix != "/antigravity" { t.Skip("仅在 Antigravity 模式下运行") } + geminiKey := requireGeminiAPIKey(t) // 测试通过 Gemini 端点调用 Claude 模型 claudeViaGemini := []string{ @@ -789,11 +833,11 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_通过Gemini端点", func(t *testing.T) { - testGeminiGenerate(t, model, false) + testGeminiGenerate(t, geminiKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) { - testGeminiGenerate(t, model, true) + testGeminiGenerate(t, geminiKey, model, true) }) } } diff --git a/backend/internal/integration/e2e_helpers_test.go b/backend/internal/integration/e2e_helpers_test.go new file mode 100644 index 00000000..7d266bcb --- /dev/null +++ b/backend/internal/integration/e2e_helpers_test.go @@ -0,0 +1,48 @@ +//go:build e2e + +package integration + +import ( + "os" + "strings" + "testing" +) + +// ============================================================================= +// E2E Mock 模式支持 +// ============================================================================= +// 当 E2E_MOCK=true 时,使用本地 Mock 响应替代真实 API 调用。 +// 这允许在没有真实 API Key 的环境(如 CI)中验证基本的请求/响应流程。 + +// isMockMode 检查是否启用 Mock 模式 +func isMockMode() bool { + return strings.EqualFold(os.Getenv("E2E_MOCK"), "true") +} + +// skipIfNoRealAPI 如果未配置真实 API Key 且不在 Mock 模式,则跳过测试 +func skipIfNoRealAPI(t *testing.T) { + t.Helper() + if isMockMode() { + return // Mock 模式下不跳过 + } + claudeKey := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) + geminiKey := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) + if claudeKey == "" && geminiKey == "" { + t.Skip("未设置 API Key 且未启用 Mock 模式,跳过测试") + } +} + +// ============================================================================= +// API Key 脱敏(Task 6.10) +// ============================================================================= + +// safeLogKey 安全地记录 API Key(仅显示前 8 位) +func safeLogKey(t *testing.T, prefix string, key string) { + t.Helper() + key = strings.TrimSpace(key) + if len(key) <= 8 { + t.Logf("%s: ***(长度: %d)", prefix, len(key)) + return + } + t.Logf("%s: %s...(长度: %d)", prefix, key[:8], len(key)) +} diff --git a/backend/internal/integration/e2e_user_flow_test.go b/backend/internal/integration/e2e_user_flow_test.go new file mode 100644 index 00000000..5489d0a3 --- /dev/null +++ b/backend/internal/integration/e2e_user_flow_test.go @@ -0,0 +1,317 @@ +//go:build e2e + +package integration + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" +) + +// E2E 用户流程测试 +// 测试完整的用户操作链路:注册 → 登录 → 创建 API Key → 调用网关 → 查询用量 + +var ( + testUserEmail = "e2e-test-" + fmt.Sprintf("%d", time.Now().UnixMilli()) + "@test.local" + testUserPassword = "E2eTest@12345" + testUserName = "e2e-test-user" +) + +// TestUserRegistrationAndLogin 测试用户注册和登录流程 +func TestUserRegistrationAndLogin(t *testing.T) { + // 步骤 1: 注册新用户 + t.Run("注册新用户", func(t *testing.T) { + payload := map[string]string{ + "email": testUserEmail, + "password": testUserPassword, + "username": testUserName, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/register", body, "") + if err != nil { + t.Skipf("注册接口不可用,跳过用户流程测试: %v", err) + return + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + // 注册可能返回 200(成功)或 400(邮箱已存在)或 403(注册已关闭) + switch resp.StatusCode { + case 200: + t.Logf("✅ 用户注册成功: %s", testUserEmail) + case 400: + t.Logf("⚠️ 用户可能已存在: %s", string(respBody)) + case 403: + t.Skipf("注册功能已关闭: %s", string(respBody)) + default: + t.Logf("⚠️ 注册返回 HTTP %d: %s(继续尝试登录)", resp.StatusCode, string(respBody)) + } + }) + + // 步骤 2: 登录获取 JWT + var accessToken string + t.Run("用户登录获取JWT", func(t *testing.T) { + payload := map[string]string{ + "email": testUserEmail, + "password": testUserPassword, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + if err != nil { + t.Fatalf("登录请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Skipf("登录失败 HTTP %d: %s(可能需要先注册用户)", resp.StatusCode, string(respBody)) + return + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + t.Fatalf("解析登录响应失败: %v", err) + } + + // 尝试从标准响应格式获取 token + if token, ok := result["access_token"].(string); ok && token != "" { + accessToken = token + } else if data, ok := result["data"].(map[string]any); ok { + if token, ok := data["access_token"].(string); ok { + accessToken = token + } + } + + if accessToken == "" { + t.Skipf("未获取到 access_token,响应: %s", string(respBody)) + return + } + + // 验证 token 不为空且格式基本正确 + if len(accessToken) < 10 { + t.Fatalf("access_token 格式异常: %s", accessToken) + } + + t.Logf("✅ 登录成功,获取 JWT(长度: %d)", len(accessToken)) + }) + + if accessToken == "" { + t.Skip("未获取到 JWT,跳过后续测试") + return + } + + // 步骤 3: 使用 JWT 获取当前用户信息 + t.Run("获取当前用户信息", func(t *testing.T) { + resp, err := doRequest(t, "GET", "/api/user/me", nil, accessToken) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + t.Logf("✅ 成功获取用户信息") + }) +} + +// TestAPIKeyLifecycle 测试 API Key 的创建和使用 +func TestAPIKeyLifecycle(t *testing.T) { + // 先登录获取 JWT + accessToken := loginTestUser(t) + if accessToken == "" { + t.Skip("无法登录,跳过 API Key 生命周期测试") + return + } + + var apiKey string + + // 步骤 1: 创建 API Key + t.Run("创建API_Key", func(t *testing.T) { + payload := map[string]string{ + "name": "e2e-test-key-" + fmt.Sprintf("%d", time.Now().UnixMilli()), + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/keys", body, accessToken) + if err != nil { + t.Fatalf("创建 API Key 请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Skipf("创建 API Key 失败 HTTP %d: %s", resp.StatusCode, string(respBody)) + return + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + // 从响应中提取 key + if key, ok := result["key"].(string); ok { + apiKey = key + } else if data, ok := result["data"].(map[string]any); ok { + if key, ok := data["key"].(string); ok { + apiKey = key + } + } + + if apiKey == "" { + t.Skipf("未获取到 API Key,响应: %s", string(respBody)) + return + } + + // 验证 API Key 脱敏日志(只显示前 8 位) + masked := apiKey + if len(masked) > 8 { + masked = masked[:8] + "..." + } + t.Logf("✅ API Key 创建成功: %s", masked) + }) + + if apiKey == "" { + t.Skip("未创建 API Key,跳过后续测试") + return + } + + // 步骤 2: 使用 API Key 调用网关(需要 Claude 或 Gemini 可用) + t.Run("使用API_Key调用网关", func(t *testing.T) { + // 尝试调用 models 列表(最轻量的 API 调用) + resp, err := doRequest(t, "GET", "/v1/models", nil, apiKey) + if err != nil { + t.Fatalf("网关请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + // 可能返回 200(成功)或 402(余额不足)或 403(无可用账户) + switch { + case resp.StatusCode == 200: + t.Logf("✅ API Key 网关调用成功") + case resp.StatusCode == 402: + t.Logf("⚠️ 余额不足,但 API Key 认证通过") + case resp.StatusCode == 403: + t.Logf("⚠️ 无可用账户,但 API Key 认证通过") + default: + t.Logf("⚠️ 网关返回 HTTP %d: %s", resp.StatusCode, string(respBody)) + } + }) + + // 步骤 3: 查询用量记录 + t.Run("查询用量记录", func(t *testing.T) { + resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken) + if err != nil { + t.Fatalf("用量查询请求失败: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + t.Logf("⚠️ 用量查询返回 HTTP %d: %s", resp.StatusCode, string(body)) + return + } + + t.Logf("✅ 用量查询成功") + }) +} + +// ============================================================================= +// 辅助函数 +// ============================================================================= + +func doRequest(t *testing.T, method, path string, body []byte, token string) (*http.Response, error) { + t.Helper() + + url := baseURL + path + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + + req, err := http.NewRequest(method, url, bodyReader) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + client := &http.Client{Timeout: 30 * time.Second} + return client.Do(req) +} + +func loginTestUser(t *testing.T) string { + t.Helper() + + // 先尝试用管理员账户登录 + adminEmail := getEnv("ADMIN_EMAIL", "admin@sub2api.local") + adminPassword := getEnv("ADMIN_PASSWORD", "") + + if adminPassword == "" { + // 尝试用测试用户 + adminEmail = testUserEmail + adminPassword = testUserPassword + } + + payload := map[string]string{ + "email": adminEmail, + "password": adminPassword, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + if err != nil { + return "" + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return "" + } + + respBody, _ := io.ReadAll(resp.Body) + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if token, ok := result["access_token"].(string); ok { + return token + } + if data, ok := result["data"].(map[string]any); ok { + if token, ok := data["access_token"].(string); ok { + return token + } + } + + return "" +} + +// redactAPIKey API Key 脱敏,只显示前 8 位 +func redactAPIKey(key string) string { + key = strings.TrimSpace(key) + if len(key) <= 8 { + return "***" + } + return key[:8] + "..." +} diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go index 0c379c0f..e362274f 100644 --- a/backend/internal/middleware/rate_limiter_test.go +++ b/backend/internal/middleware/rate_limiter_test.go @@ -60,6 +60,49 @@ func TestRateLimiterFailureModes(t *testing.T) { require.Equal(t, http.StatusTooManyRequests, recorder.Code) } +func TestRateLimiterDifferentIPsIndependent(t *testing.T) { + gin.SetMode(gin.TestMode) + + callCounts := make(map[string]int64) + originalRun := rateLimitRun + rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) { + callCounts[key]++ + return callCounts[key], false, nil + } + t.Cleanup(func() { + rateLimitRun = originalRun + }) + + limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"})) + + router := gin.New() + router.Use(limiter.Limit("api", 1, time.Second)) + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + // 第一个 IP 的请求应通过 + req1 := httptest.NewRequest(http.MethodGet, "/test", nil) + req1.RemoteAddr = "10.0.0.1:1234" + rec1 := httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code, "第一个 IP 的第一次请求应通过") + + // 第二个 IP 的请求应独立通过(不受第一个 IP 的计数影响) + req2 := httptest.NewRequest(http.MethodGet, "/test", nil) + req2.RemoteAddr = "10.0.0.2:5678" + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code, "第二个 IP 的第一次请求应独立通过") + + // 第一个 IP 的第二次请求应被限流 + req3 := httptest.NewRequest(http.MethodGet, "/test", nil) + req3.RemoteAddr = "10.0.0.1:1234" + rec3 := httptest.NewRecorder() + router.ServeHTTP(rec3, req3) + require.Equal(t, http.StatusTooManyRequests, rec3.Code, "第一个 IP 的第二次请求应被限流") +} + func TestRateLimiterSuccessAndLimit(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/model/error_passthrough_rule.go b/backend/internal/model/error_passthrough_rule.go index d4fc16e3..620736cd 100644 --- a/backend/internal/model/error_passthrough_rule.go +++ b/backend/internal/model/error_passthrough_rule.go @@ -18,6 +18,7 @@ type ErrorPassthroughRule struct { ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用) PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息 CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用) + SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录 Description *string `json:"description"` // 规则描述 CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 8a29cd10..7cc68060 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -27,7 +27,7 @@ type ClaudeMessage struct { // ThinkingConfig Thinking 配置 type ThinkingConfig struct { - Type string `json:"type"` // "enabled" or "disabled" + Type string `json:"type"` // "enabled" / "adaptive" / "disabled" BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget } @@ -151,6 +151,9 @@ var claudeModels = []modelDef{ {ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"}, {ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"}, {ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"}, + {ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"}, + {ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"}, + {ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"}, } // Antigravity 支持的 Gemini 模型 @@ -161,6 +164,10 @@ var geminiModels = []modelDef{ {ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"}, {ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"}, {ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"}, + {ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"}, + {ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"}, + {ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"}, + {ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"}, {ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"}, {ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"}, } diff --git a/backend/internal/pkg/antigravity/claude_types_test.go b/backend/internal/pkg/antigravity/claude_types_test.go new file mode 100644 index 00000000..f7cb0a24 --- /dev/null +++ b/backend/internal/pkg/antigravity/claude_types_test.go @@ -0,0 +1,26 @@ +package antigravity + +import "testing" + +func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) { + t.Parallel() + + models := DefaultModels() + byID := make(map[string]ClaudeModel, len(models)) + for _, m := range models { + byID[m.ID] = m + } + + requiredIDs := []string{ + "claude-opus-4-6-thinking", + "gemini-3.1-flash-image", + "gemini-3.1-flash-image-preview", + "gemini-3-pro-image", // legacy compatibility + } + + for _, id := range requiredIDs { + if _, ok := byID[id]; !ok { + t.Fatalf("expected model %q to be exposed in DefaultModels", id) + } + } +} diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index a6279b11..d46bbc45 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -14,6 +14,9 @@ import ( "net/url" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" ) // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) @@ -33,7 +36,7 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri // 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("User-Agent", UserAgent) + req.Header.Set("User-Agent", GetUserAgent()) return req, nil } @@ -115,6 +118,23 @@ type LoadCodeAssistResponse struct { IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"` } +// OnboardUserRequest onboardUser 请求 +type OnboardUserRequest struct { + TierID string `json:"tierId"` + Metadata struct { + IDEType string `json:"ideType"` + Platform string `json:"platform,omitempty"` + PluginType string `json:"pluginType,omitempty"` + } `json:"metadata"` +} + +// OnboardUserResponse onboardUser 响应 +type OnboardUserResponse struct { + Name string `json:"name,omitempty"` + Done bool `json:"done"` + Response map[string]any `json:"response,omitempty"` +} + // GetTier 获取账户类型 // 优先返回 paidTier(付费订阅级别),否则返回 currentTier func (r *LoadCodeAssistResponse) GetTier() string { @@ -132,22 +152,26 @@ type Client struct { httpClient *http.Client } -func NewClient(proxyURL string) *Client { +func NewClient(proxyURL string) (*Client, error) { client := &http.Client{ Timeout: 30 * time.Second, } - if strings.TrimSpace(proxyURL) != "" { - if proxyURLParsed, err := url.Parse(proxyURL); err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURLParsed), - } + _, parsed, err := proxyurl.Parse(proxyURL) + if err != nil { + return nil, err + } + if parsed != nil { + transport := &http.Transport{} + if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { + return nil, fmt.Errorf("configure proxy: %w", err) } + client.Transport = transport } return &Client{ httpClient: client, - } + }, nil } // isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) @@ -187,9 +211,14 @@ func shouldFallbackToNextURL(err error, statusCode int) bool { // ExchangeCode 用 authorization code 交换 token func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) { + clientSecret, err := getClientSecret() + if err != nil { + return nil, err + } + params := url.Values{} params.Set("client_id", ClientID) - params.Set("client_secret", ClientSecret) + params.Set("client_secret", clientSecret) params.Set("code", code) params.Set("redirect_uri", RedirectURI) params.Set("grant_type", "authorization_code") @@ -226,9 +255,14 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (* // RefreshToken 刷新 access_token func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) { + clientSecret, err := getClientSecret() + if err != nil { + return nil, err + } + params := url.Values{} params.Set("client_id", ClientID) - params.Set("client_secret", ClientSecret) + params.Set("client_secret", clientSecret) params.Set("refresh_token", refreshToken) params.Set("grant_type", "refresh_token") @@ -316,7 +350,7 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC } req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) + req.Header.Set("User-Agent", GetUserAgent()) resp, err := c.httpClient.Do(req) if err != nil { @@ -361,6 +395,117 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC return nil, nil, lastErr } +// OnboardUser 触发账号 onboarding,并返回 project_id +// 说明: +// 1) 部分账号 loadCodeAssist 不会立即返回 cloudaicompanionProject; +// 2) 这时需要调用 onboardUser 完成初始化,之后才能拿到 project_id。 +func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) { + tierID = strings.TrimSpace(tierID) + if tierID == "" { + return "", fmt.Errorf("tier_id 为空") + } + + reqBody := OnboardUserRequest{TierID: tierID} + reqBody.Metadata.IDEType = "ANTIGRAVITY" + reqBody.Metadata.Platform = "PLATFORM_UNSPECIFIED" + reqBody.Metadata.PluginType = "GEMINI" + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("序列化请求失败: %w", err) + } + + availableURLs := BaseURLs + var lastErr error + + for urlIdx, baseURL := range availableURLs { + apiURL := baseURL + "/v1internal:onboardUser" + + for attempt := 1; attempt <= 5; attempt++ { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes)) + if err != nil { + lastErr = fmt.Errorf("创建请求失败: %w", err) + break + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", GetUserAgent()) + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("onboardUser 请求失败: %w", err) + if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + log.Printf("[antigravity] onboardUser URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + break + } + return "", lastErr + } + + respBodyBytes, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + return "", fmt.Errorf("读取响应失败: %w", err) + } + + if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + log.Printf("[antigravity] onboardUser URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + break + } + + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("onboardUser 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + return "", lastErr + } + + var onboardResp OnboardUserResponse + if err := json.Unmarshal(respBodyBytes, &onboardResp); err != nil { + lastErr = fmt.Errorf("onboardUser 响应解析失败: %w", err) + return "", lastErr + } + + if onboardResp.Done { + if projectID := extractProjectIDFromOnboardResponse(onboardResp.Response); projectID != "" { + DefaultURLAvailability.MarkSuccess(baseURL) + return projectID, nil + } + lastErr = fmt.Errorf("onboardUser 完成但未返回 project_id") + return "", lastErr + } + + // done=false 时等待后重试(与 CLIProxyAPI 行为一致) + select { + case <-time.After(2 * time.Second): + case <-ctx.Done(): + return "", ctx.Err() + } + } + } + + if lastErr != nil { + return "", lastErr + } + return "", fmt.Errorf("onboardUser 未返回 project_id") +} + +func extractProjectIDFromOnboardResponse(resp map[string]any) string { + if len(resp) == 0 { + return "" + } + + if v, ok := resp["cloudaicompanionProject"]; ok { + switch project := v.(type) { + case string: + return strings.TrimSpace(project) + case map[string]any: + if id, ok := project["id"].(string); ok { + return strings.TrimSpace(id) + } + } + } + + return "" +} + // ModelQuotaInfo 模型配额信息 type ModelQuotaInfo struct { RemainingFraction float64 `json:"remainingFraction"` @@ -404,7 +549,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI } req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) + req.Header.Set("User-Agent", GetUserAgent()) resp, err := c.httpClient.Do(req) if err != nil { diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go new file mode 100644 index 00000000..20b57833 --- /dev/null +++ b/backend/internal/pkg/antigravity/client_test.go @@ -0,0 +1,1770 @@ +//go:build unit + +package antigravity + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +// --------------------------------------------------------------------------- +// NewAPIRequestWithURL +// --------------------------------------------------------------------------- + +func TestNewAPIRequestWithURL_普通请求(t *testing.T) { + ctx := context.Background() + baseURL := "https://example.com" + action := "generateContent" + token := "test-token" + body := []byte(`{"prompt":"hello"}`) + + req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + // 验证 URL 不含 ?alt=sse + expectedURL := "https://example.com/v1internal:generateContent" + if req.URL.String() != expectedURL { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL) + } + + // 验证请求方法 + if req.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", req.Method) + } + + // 验证 Headers + if ct := req.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if auth := req.Header.Get("Authorization"); auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ua := req.Header.Get("User-Agent"); ua != GetUserAgent() { + t.Errorf("User-Agent 不匹配: got %s, want %s", ua, GetUserAgent()) + } +} + +func TestNewAPIRequestWithURL_流式请求(t *testing.T) { + ctx := context.Background() + baseURL := "https://example.com" + action := "streamGenerateContent" + token := "tok" + body := []byte(`{}`) + + req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + expectedURL := "https://example.com/v1internal:streamGenerateContent?alt=sse" + if req.URL.String() != expectedURL { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL) + } +} + +func TestNewAPIRequestWithURL_空Body(t *testing.T) { + ctx := context.Background() + req, err := NewAPIRequestWithURL(ctx, "https://example.com", "test", "tok", nil) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + if req.Body == nil { + t.Error("Body 应该非 nil(bytes.NewReader(nil) 会返回空 reader)") + } +} + +// --------------------------------------------------------------------------- +// NewAPIRequest +// --------------------------------------------------------------------------- + +func TestNewAPIRequest_使用默认URL(t *testing.T) { + ctx := context.Background() + req, err := NewAPIRequest(ctx, "generateContent", "tok", []byte(`{}`)) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + expected := BaseURL + "/v1internal:generateContent" + if req.URL.String() != expected { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expected) + } +} + +// --------------------------------------------------------------------------- +// TierInfo.UnmarshalJSON +// --------------------------------------------------------------------------- + +func TestTierInfo_UnmarshalJSON_字符串格式(t *testing.T) { + data := []byte(`"free-tier"`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if tier.ID != "free-tier" { + t.Errorf("ID 不匹配: got %s, want free-tier", tier.ID) + } + if tier.Name != "" { + t.Errorf("Name 应为空: got %s", tier.Name) + } +} + +func TestTierInfo_UnmarshalJSON_对象格式(t *testing.T) { + data := []byte(`{"id":"g1-pro-tier","name":"Pro","description":"Pro plan"}`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if tier.ID != "g1-pro-tier" { + t.Errorf("ID 不匹配: got %s, want g1-pro-tier", tier.ID) + } + if tier.Name != "Pro" { + t.Errorf("Name 不匹配: got %s, want Pro", tier.Name) + } + if tier.Description != "Pro plan" { + t.Errorf("Description 不匹配: got %s, want Pro plan", tier.Description) + } +} + +func TestTierInfo_UnmarshalJSON_null(t *testing.T) { + data := []byte(`null`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化 null 失败: %v", err) + } + if tier.ID != "" { + t.Errorf("null 场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_空数据(t *testing.T) { + data := []byte(``) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化空数据失败: %v", err) + } + if tier.ID != "" { + t.Errorf("空数据场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_空格包裹null(t *testing.T) { + data := []byte(` null `) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化空格 null 失败: %v", err) + } + if tier.ID != "" { + t.Errorf("空格 null 场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) { + // 模拟 LoadCodeAssistResponse 中的嵌套反序列化 + jsonData := `{"currentTier":"free-tier","paidTier":{"id":"g1-ultra-tier","name":"Ultra"}}` + var resp LoadCodeAssistResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化嵌套结构失败: %v", err) + } + if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" { + t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier) + } + if resp.PaidTier == nil || resp.PaidTier.ID != "g1-ultra-tier" { + t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier) + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssistResponse.GetTier +// --------------------------------------------------------------------------- + +func TestGetTier_PaidTier优先(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + PaidTier: &TierInfo{ID: "g1-pro-tier"}, + } + if got := resp.GetTier(); got != "g1-pro-tier" { + t.Errorf("应返回 paidTier: got %s", got) + } +} + +func TestGetTier_回退到CurrentTier(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + } + if got := resp.GetTier(); got != "free-tier" { + t.Errorf("应返回 currentTier: got %s", got) + } +} + +func TestGetTier_PaidTier为空ID(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + PaidTier: &TierInfo{ID: ""}, + } + // paidTier.ID 为空时应回退到 currentTier + if got := resp.GetTier(); got != "free-tier" { + t.Errorf("paidTier.ID 为空时应回退到 currentTier: got %s", got) + } +} + +func TestGetTier_两者都为nil(t *testing.T) { + resp := &LoadCodeAssistResponse{} + if got := resp.GetTier(); got != "" { + t.Errorf("两者都为 nil 时应返回空字符串: got %s", got) + } +} + +// --------------------------------------------------------------------------- +// NewClient +// --------------------------------------------------------------------------- + +func mustNewClient(t *testing.T, proxyURL string) *Client { + t.Helper() + client, err := NewClient(proxyURL) + if err != nil { + t.Fatalf("NewClient(%q) failed: %v", proxyURL, err) + } + return client +} + +func TestNewClient_无代理(t *testing.T) { + client, err := NewClient("") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } + if client == nil { + t.Fatal("NewClient 返回 nil") + } + if client.httpClient == nil { + t.Fatal("httpClient 为 nil") + } + if client.httpClient.Timeout != 30*time.Second { + t.Errorf("Timeout 不匹配: got %v, want 30s", client.httpClient.Timeout) + } + // 无代理时 Transport 应为 nil(使用默认) + if client.httpClient.Transport != nil { + t.Error("无代理时 Transport 应为 nil") + } +} + +func TestNewClient_有代理(t *testing.T) { + client, err := NewClient("http://proxy.example.com:8080") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } + if client == nil { + t.Fatal("NewClient 返回 nil") + } + if client.httpClient.Transport == nil { + t.Fatal("有代理时 Transport 不应为 nil") + } +} + +func TestNewClient_空格代理(t *testing.T) { + client, err := NewClient(" ") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } + if client == nil { + t.Fatal("NewClient 返回 nil") + } + // 空格代理应等同于无代理 + if client.httpClient.Transport != nil { + t.Error("空格代理 Transport 应为 nil") + } +} + +func TestNewClient_无效代理URL(t *testing.T) { + // 无效 URL 应返回 error + _, err := NewClient("://invalid") + if err == nil { + t.Fatal("无效代理 URL 应返回错误") + } + if !strings.Contains(err.Error(), "invalid proxy URL") { + t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error()) + } +} + +// --------------------------------------------------------------------------- +// isConnectionError +// --------------------------------------------------------------------------- + +func TestIsConnectionError_nil(t *testing.T) { + if isConnectionError(nil) { + t.Error("nil 错误不应判定为连接错误") + } +} + +func TestIsConnectionError_超时错误(t *testing.T) { + // 使用 net.OpError 包装超时 + err := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: &timeoutError{}, + } + if !isConnectionError(err) { + t.Error("超时错误应判定为连接错误") + } +} + +// timeoutError 实现 net.Error 接口用于测试 +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + +func TestIsConnectionError_netOpError(t *testing.T) { + err := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: fmt.Errorf("connection refused"), + } + if !isConnectionError(err) { + t.Error("net.OpError 应判定为连接错误") + } +} + +func TestIsConnectionError_urlError(t *testing.T) { + err := &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: fmt.Errorf("some error"), + } + if !isConnectionError(err) { + t.Error("url.Error 应判定为连接错误") + } +} + +func TestIsConnectionError_普通错误(t *testing.T) { + err := fmt.Errorf("some random error") + if isConnectionError(err) { + t.Error("普通错误不应判定为连接错误") + } +} + +func TestIsConnectionError_包装的netOpError(t *testing.T) { + inner := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: fmt.Errorf("connection refused"), + } + err := fmt.Errorf("wrapping: %w", inner) + if !isConnectionError(err) { + t.Error("被包装的 net.OpError 应判定为连接错误") + } +} + +// --------------------------------------------------------------------------- +// shouldFallbackToNextURL +// --------------------------------------------------------------------------- + +func TestShouldFallbackToNextURL_连接错误(t *testing.T) { + err := &net.OpError{Op: "dial", Net: "tcp", Err: fmt.Errorf("refused")} + if !shouldFallbackToNextURL(err, 0) { + t.Error("连接错误应触发 URL 降级") + } +} + +func TestShouldFallbackToNextURL_状态码(t *testing.T) { + tests := []struct { + name string + statusCode int + want bool + }{ + {"429 Too Many Requests", http.StatusTooManyRequests, true}, + {"408 Request Timeout", http.StatusRequestTimeout, true}, + {"404 Not Found", http.StatusNotFound, true}, + {"500 Internal Server Error", http.StatusInternalServerError, true}, + {"502 Bad Gateway", http.StatusBadGateway, true}, + {"503 Service Unavailable", http.StatusServiceUnavailable, true}, + {"200 OK", http.StatusOK, false}, + {"201 Created", http.StatusCreated, false}, + {"400 Bad Request", http.StatusBadRequest, false}, + {"401 Unauthorized", http.StatusUnauthorized, false}, + {"403 Forbidden", http.StatusForbidden, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shouldFallbackToNextURL(nil, tt.statusCode) + if got != tt.want { + t.Errorf("shouldFallbackToNextURL(nil, %d) = %v, want %v", tt.statusCode, got, tt.want) + } + }) + } +} + +func TestShouldFallbackToNextURL_无错误且200(t *testing.T) { + if shouldFallbackToNextURL(nil, http.StatusOK) { + t.Error("无错误且 200 不应触发 URL 降级") + } +} + +// --------------------------------------------------------------------------- +// Client.ExchangeCode (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_ExchangeCode_成功(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求方法 + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + // 验证 Content-Type + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + // 验证请求体参数 + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + if r.FormValue("code") != "auth-code" { + t.Errorf("code 不匹配: got %s", r.FormValue("code")) + } + if r.FormValue("code_verifier") != "verifier123" { + t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier")) + } + if r.FormValue("grant_type") != "authorization_code" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "access-tok", + ExpiresIn: 3600, + TokenType: "Bearer", + RefreshToken: "refresh-tok", + }) + })) + defer server.Close() + + // 临时替换 TokenURL(该函数直接使用常量,需要我们通过构建自定义 client 来绕过) + // 由于 ExchangeCode 硬编码了 TokenURL,我们需要直接测试 HTTP client 的行为 + // 这里通过构造一个直接调用 mock server 的测试 + client := &Client{httpClient: server.Client()} + + // 由于 ExchangeCode 使用硬编码的 TokenURL,我们无法直接注入 mock server URL + // 需要使用 httptest 的 Transport 重定向 + originalTokenURL := TokenURL + // 我们改为直接构造请求来测试逻辑 + _ = originalTokenURL + _ = client + + // 改用直接构造请求测试 mock server 响应 + ctx := context.Background() + params := url.Values{} + params.Set("client_id", ClientID) + params.Set("client_secret", "test-secret") + params.Set("code", "auth-code") + params.Set("redirect_uri", RedirectURI) + params.Set("grant_type", "authorization_code") + params.Set("code_verifier", "verifier123") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode())) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + t.Fatalf("解码失败: %v", err) + } + if tokenResp.AccessToken != "access-tok" { + t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken) + } + if tokenResp.RefreshToken != "refresh-tok" { + t.Errorf("RefreshToken 不匹配: got %s", tokenResp.RefreshToken) + } +} + +func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) + + client := mustNewClient(t, "") + _, err := client.ExchangeCode(context.Background(), "code", "verifier") + if err == nil { + t.Fatal("缺少 client_secret 时应返回错误") + } + if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { + t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_服务器返回错误(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant"}`)) + })) + defer server.Close() + + // 直接测试 mock server 的错误响应 + resp, err := server.Client().Get(server.URL) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("状态码不匹配: got %d, want 400", resp.StatusCode) + } +} + +// --------------------------------------------------------------------------- +// Client.RefreshToken (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_RefreshToken_MockServer(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("grant_type") != "refresh_token" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + if r.FormValue("refresh_token") != "old-refresh-tok" { + t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "new-access-tok", + ExpiresIn: 3600, + TokenType: "Bearer", + }) + })) + defer server.Close() + + ctx := context.Background() + params := url.Values{} + params.Set("client_id", ClientID) + params.Set("client_secret", "test-secret") + params.Set("refresh_token", "old-refresh-tok") + params.Set("grant_type", "refresh_token") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode())) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + t.Fatalf("解码失败: %v", err) + } + if tokenResp.AccessToken != "new-access-tok" { + t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken) + } +} + +func TestClient_RefreshToken_无ClientSecret(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) + + client := mustNewClient(t, "") + _, err := client.RefreshToken(context.Background(), "refresh-tok") + if err == nil { + t.Fatal("缺少 client_secret 时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.GetUserInfo (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_GetUserInfo_成功(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-access-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(UserInfo{ + Email: "user@example.com", + Name: "Test User", + GivenName: "Test", + FamilyName: "User", + Picture: "https://example.com/photo.jpg", + }) + })) + defer server.Close() + + // 直接通过 mock server 测试 GetUserInfo 的行为逻辑 + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Authorization", "Bearer test-access-token") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var userInfo UserInfo + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + t.Fatalf("解码失败: %v", err) + } + if userInfo.Email != "user@example.com" { + t.Errorf("Email 不匹配: got %s", userInfo.Email) + } + if userInfo.Name != "Test User" { + t.Errorf("Name 不匹配: got %s", userInfo.Name) + } +} + +func TestClient_GetUserInfo_服务器返回错误(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) + })) + defer server.Close() + + resp, err := server.Client().Get(server.URL) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("状态码不匹配: got %d, want 401", resp.StatusCode) + } +} + +// --------------------------------------------------------------------------- +// TokenResponse / UserInfo JSON 序列化 +// --------------------------------------------------------------------------- + +func TestTokenResponse_JSON序列化(t *testing.T) { + jsonData := `{"access_token":"at","expires_in":3600,"token_type":"Bearer","scope":"openid","refresh_token":"rt"}` + var resp TokenResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if resp.AccessToken != "at" { + t.Errorf("AccessToken 不匹配: got %s", resp.AccessToken) + } + if resp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d", resp.ExpiresIn) + } + if resp.RefreshToken != "rt" { + t.Errorf("RefreshToken 不匹配: got %s", resp.RefreshToken) + } +} + +func TestUserInfo_JSON序列化(t *testing.T) { + jsonData := `{"email":"a@b.com","name":"Alice"}` + var info UserInfo + if err := json.Unmarshal([]byte(jsonData), &info); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if info.Email != "a@b.com" { + t.Errorf("Email 不匹配: got %s", info.Email) + } + if info.Name != "Alice" { + t.Errorf("Name 不匹配: got %s", info.Name) + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssistResponse JSON 序列化 +// --------------------------------------------------------------------------- + +func TestLoadCodeAssistResponse_完整JSON(t *testing.T) { + jsonData := `{ + "cloudaicompanionProject": "proj-123", + "currentTier": "free-tier", + "paidTier": {"id": "g1-pro-tier", "name": "Pro"}, + "ineligibleTiers": [{"tier": {"id": "g1-ultra-tier"}, "reasonCode": "INELIGIBLE_ACCOUNT"}] + }` + var resp LoadCodeAssistResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if resp.CloudAICompanionProject != "proj-123" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if resp.GetTier() != "g1-pro-tier" { + t.Errorf("GetTier 不匹配: got %s", resp.GetTier()) + } + if len(resp.IneligibleTiers) != 1 { + t.Fatalf("IneligibleTiers 数量不匹配: got %d", len(resp.IneligibleTiers)) + } + if resp.IneligibleTiers[0].ReasonCode != "INELIGIBLE_ACCOUNT" { + t.Errorf("ReasonCode 不匹配: got %s", resp.IneligibleTiers[0].ReasonCode) + } +} + +// =========================================================================== +// 以下为新增测试:真正调用 Client 方法,通过 RoundTripper 拦截 HTTP 请求 +// =========================================================================== + +// redirectRoundTripper 将请求中特定前缀的 URL 重定向到 httptest server +type redirectRoundTripper struct { + // 原始 URL 前缀 -> 替换目标 URL 的映射 + redirects map[string]string + transport http.RoundTripper +} + +func (rt *redirectRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + originalURL := req.URL.String() + for prefix, target := range rt.redirects { + if strings.HasPrefix(originalURL, prefix) { + newURL := target + strings.TrimPrefix(originalURL, prefix) + parsed, err := url.Parse(newURL) + if err != nil { + return nil, err + } + req.URL = parsed + break + } + } + if rt.transport == nil { + return http.DefaultTransport.RoundTrip(req) + } + return rt.transport.RoundTrip(req) +} + +// newTestClientWithRedirect 创建一个 Client,将指定 URL 前缀的请求重定向到 mock server +func newTestClientWithRedirect(redirects map[string]string) *Client { + return &Client{ + httpClient: &http.Client{ + Timeout: 10 * time.Second, + Transport: &redirectRoundTripper{ + redirects: redirects, + }, + }, + } +} + +// --------------------------------------------------------------------------- +// Client.ExchangeCode - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_ExchangeCode_Success_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + if r.FormValue("code") != "test-auth-code" { + t.Errorf("code 不匹配: got %s", r.FormValue("code")) + } + if r.FormValue("code_verifier") != "test-verifier" { + t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier")) + } + if r.FormValue("grant_type") != "authorization_code" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + if r.FormValue("redirect_uri") != RedirectURI { + t.Errorf("redirect_uri 不匹配: got %s", r.FormValue("redirect_uri")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "new-access-token", + ExpiresIn: 3600, + TokenType: "Bearer", + Scope: "openid email", + RefreshToken: "new-refresh-token", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + tokenResp, err := client.ExchangeCode(context.Background(), "test-auth-code", "test-verifier") + if err != nil { + t.Fatalf("ExchangeCode 失败: %v", err) + } + if tokenResp.AccessToken != "new-access-token" { + t.Errorf("AccessToken 不匹配: got %s, want new-access-token", tokenResp.AccessToken) + } + if tokenResp.RefreshToken != "new-refresh-token" { + t.Errorf("RefreshToken 不匹配: got %s, want new-refresh-token", tokenResp.RefreshToken) + } + if tokenResp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn) + } + if tokenResp.TokenType != "Bearer" { + t.Errorf("TokenType 不匹配: got %s, want Bearer", tokenResp.TokenType) + } + if tokenResp.Scope != "openid email" { + t.Errorf("Scope 不匹配: got %s, want openid email", tokenResp.Scope) + } +} + +func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"code expired"}`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.ExchangeCode(context.Background(), "expired-code", "verifier") + if err == nil { + t.Fatal("服务器返回 400 时应返回错误") + } + if !strings.Contains(err.Error(), "token 交换失败") { + t.Errorf("错误信息应包含 'token 交换失败': got %s", err.Error()) + } + if !strings.Contains(err.Error(), "400") { + t.Errorf("错误信息应包含状态码 400: got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{invalid json`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.ExchangeCode(context.Background(), "code", "verifier") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "token 解析失败") { + t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) // 模拟慢响应 + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + _, err := client.ExchangeCode(ctx, "code", "verifier") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.RefreshToken - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_RefreshToken_Success_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("grant_type") != "refresh_token" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + if r.FormValue("refresh_token") != "my-refresh-token" { + t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token")) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "refreshed-access-token", + ExpiresIn: 3600, + TokenType: "Bearer", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + tokenResp, err := client.RefreshToken(context.Background(), "my-refresh-token") + if err != nil { + t.Fatalf("RefreshToken 失败: %v", err) + } + if tokenResp.AccessToken != "refreshed-access-token" { + t.Errorf("AccessToken 不匹配: got %s, want refreshed-access-token", tokenResp.AccessToken) + } + if tokenResp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn) + } +} + +func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"token revoked"}`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.RefreshToken(context.Background(), "revoked-token") + if err == nil { + t.Fatal("服务器返回 401 时应返回错误") + } + if !strings.Contains(err.Error(), "token 刷新失败") { + t.Errorf("错误信息应包含 'token 刷新失败': got %s", err.Error()) + } +} + +func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`not-json`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.RefreshToken(context.Background(), "refresh-tok") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "token 解析失败") { + t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error()) + } +} + +func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := client.RefreshToken(ctx, "refresh-tok") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.GetUserInfo - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_GetUserInfo_Success_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("请求方法不匹配: got %s, want GET", r.Method) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer user-access-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(UserInfo{ + Email: "test@example.com", + Name: "Test User", + GivenName: "Test", + FamilyName: "User", + Picture: "https://example.com/avatar.jpg", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + userInfo, err := client.GetUserInfo(context.Background(), "user-access-token") + if err != nil { + t.Fatalf("GetUserInfo 失败: %v", err) + } + if userInfo.Email != "test@example.com" { + t.Errorf("Email 不匹配: got %s, want test@example.com", userInfo.Email) + } + if userInfo.Name != "Test User" { + t.Errorf("Name 不匹配: got %s, want Test User", userInfo.Name) + } + if userInfo.GivenName != "Test" { + t.Errorf("GivenName 不匹配: got %s, want Test", userInfo.GivenName) + } + if userInfo.FamilyName != "User" { + t.Errorf("FamilyName 不匹配: got %s, want User", userInfo.FamilyName) + } + if userInfo.Picture != "https://example.com/avatar.jpg" { + t.Errorf("Picture 不匹配: got %s", userInfo.Picture) + } +} + +func TestClient_GetUserInfo_Unauthorized_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + _, err := client.GetUserInfo(context.Background(), "bad-token") + if err == nil { + t.Fatal("服务器返回 401 时应返回错误") + } + if !strings.Contains(err.Error(), "获取用户信息失败") { + t.Errorf("错误信息应包含 '获取用户信息失败': got %s", err.Error()) + } + if !strings.Contains(err.Error(), "401") { + t.Errorf("错误信息应包含状态码 401: got %s", err.Error()) + } +} + +func TestClient_GetUserInfo_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{broken`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + _, err := client.GetUserInfo(context.Background(), "token") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "用户信息解析失败") { + t.Errorf("错误信息应包含 '用户信息解析失败': got %s", err.Error()) + } +} + +func TestClient_GetUserInfo_ContextCanceled_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := client.GetUserInfo(ctx, "token") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.LoadCodeAssist - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +// withMockBaseURLs 临时替换 BaseURLs,测试结束后恢复 +func withMockBaseURLs(t *testing.T, urls []string) { + t.Helper() + origBaseURLs := BaseURLs + origBaseURL := BaseURL + BaseURLs = urls + if len(urls) > 0 { + BaseURL = urls[0] + } + t.Cleanup(func() { + BaseURLs = origBaseURLs + BaseURL = origBaseURL + }) +} + +func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if !strings.HasSuffix(r.URL.Path, "/v1internal:loadCodeAssist") { + t.Errorf("URL 路径不匹配: got %s", r.URL.Path) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if ua := r.Header.Get("User-Agent"); ua != GetUserAgent() { + t.Errorf("User-Agent 不匹配: got %s", ua) + } + + // 验证请求体 + var reqBody LoadCodeAssistRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + t.Fatalf("解析请求体失败: %v", err) + } + if reqBody.Metadata.IDEType != "ANTIGRAVITY" { + t.Errorf("IDEType 不匹配: got %s, want ANTIGRAVITY", reqBody.Metadata.IDEType) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "cloudaicompanionProject": "test-project-123", + "currentTier": {"id": "free-tier", "name": "Free"}, + "paidTier": {"id": "g1-pro-tier", "name": "Pro", "description": "Pro plan"} + }`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token") + if err != nil { + t.Fatalf("LoadCodeAssist 失败: %v", err) + } + if resp.CloudAICompanionProject != "test-project-123" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if resp.GetTier() != "g1-pro-tier" { + t.Errorf("GetTier 不匹配: got %s, want g1-pro-tier", resp.GetTier()) + } + if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" { + t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier) + } + if resp.PaidTier == nil || resp.PaidTier.ID != "g1-pro-tier" { + t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier) + } + // 验证原始 JSON map + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } + if rawResp["cloudaicompanionProject"] != "test-project-123" { + t.Errorf("rawResp cloudaicompanionProject 不匹配: got %v", rawResp["cloudaicompanionProject"]) + } +} + +func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"forbidden"}`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + _, _, err := client.LoadCodeAssist(context.Background(), "bad-token") + if err == nil { + t.Fatal("服务器返回 403 时应返回错误") + } + if !strings.Contains(err.Error(), "loadCodeAssist 失败") { + t.Errorf("错误信息应包含 'loadCodeAssist 失败': got %s", err.Error()) + } + if !strings.Contains(err.Error(), "403") { + t.Errorf("错误信息应包含状态码 403: got %s", err.Error()) + } +} + +func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{not valid json!!!`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + _, _, err := client.LoadCodeAssist(context.Background(), "token") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "响应解析失败") { + t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error()) + } +} + +func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) { + // 第一个 server 返回 500,第二个 server 返回成功 + callCount := 0 + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"internal"}`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "cloudaicompanionProject": "fallback-project", + "currentTier": {"id": "free-tier", "name": "Free"} + }`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + resp, _, err := client.LoadCodeAssist(context.Background(), "token") + if err != nil { + t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err) + } + if resp.CloudAICompanionProject != "fallback-project" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if callCount != 2 { + t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount) + } +} + +func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte(`{"error":"unavailable"}`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte(`{"error":"bad_gateway"}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + _, _, err := client.LoadCodeAssist(context.Background(), "token") + if err == nil { + t.Fatal("所有 URL 都失败时应返回错误") + } +} + +func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, _, err := client.LoadCodeAssist(ctx, "token") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.FetchAvailableModels - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if !strings.HasSuffix(r.URL.Path, "/v1internal:fetchAvailableModels") { + t.Errorf("URL 路径不匹配: got %s", r.URL.Path) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if ua := r.Header.Get("User-Agent"); ua != GetUserAgent() { + t.Errorf("User-Agent 不匹配: got %s", ua) + } + + // 验证请求体 + var reqBody FetchAvailableModelsRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + t.Fatalf("解析请求体失败: %v", err) + } + if reqBody.Project != "project-abc" { + t.Errorf("Project 不匹配: got %s, want project-abc", reqBody.Project) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "models": { + "gemini-2.0-flash": { + "quotaInfo": { + "remainingFraction": 0.85, + "resetTime": "2025-01-01T00:00:00Z" + } + }, + "gemini-2.5-pro": { + "quotaInfo": { + "remainingFraction": 0.5 + } + } + } + }`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc") + if err != nil { + t.Fatalf("FetchAvailableModels 失败: %v", err) + } + if resp.Models == nil { + t.Fatal("Models 不应为 nil") + } + if len(resp.Models) != 2 { + t.Errorf("Models 数量不匹配: got %d, want 2", len(resp.Models)) + } + + flashModel, ok := resp.Models["gemini-2.0-flash"] + if !ok { + t.Fatal("缺少 gemini-2.0-flash 模型") + } + if flashModel.QuotaInfo == nil { + t.Fatal("gemini-2.0-flash QuotaInfo 不应为 nil") + } + if flashModel.QuotaInfo.RemainingFraction != 0.85 { + t.Errorf("RemainingFraction 不匹配: got %f, want 0.85", flashModel.QuotaInfo.RemainingFraction) + } + if flashModel.QuotaInfo.ResetTime != "2025-01-01T00:00:00Z" { + t.Errorf("ResetTime 不匹配: got %s", flashModel.QuotaInfo.ResetTime) + } + + proModel, ok := resp.Models["gemini-2.5-pro"] + if !ok { + t.Fatal("缺少 gemini-2.5-pro 模型") + } + if proModel.QuotaInfo == nil { + t.Fatal("gemini-2.5-pro QuotaInfo 不应为 nil") + } + if proModel.QuotaInfo.RemainingFraction != 0.5 { + t.Errorf("RemainingFraction 不匹配: got %f, want 0.5", proModel.QuotaInfo.RemainingFraction) + } + + // 验证原始 JSON map + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } + if rawResp["models"] == nil { + t.Error("rawResp models 不应为 nil") + } +} + +func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"forbidden"}`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + _, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj") + if err == nil { + t.Fatal("服务器返回 403 时应返回错误") + } + if !strings.Contains(err.Error(), "fetchAvailableModels 失败") { + t.Errorf("错误信息应包含 'fetchAvailableModels 失败': got %s", err.Error()) + } +} + +func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`<<>>`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "响应解析失败") { + t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error()) + } +} + +func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) { + callCount := 0 + // 第一个 server 返回 429,第二个 server 返回成功 + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":"rate_limited"}`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models": {"model-a": {}}}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err) + } + if _, ok := resp.Models["model-a"]; !ok { + t.Error("应返回 fallback server 的模型") + } + if callCount != 2 { + t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount) + } +} + +func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`not found`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`internal error`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err == nil { + t.Fatal("所有 URL 都失败时应返回错误") + } +} + +func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, _, err := client.FetchAvailableModels(ctx, "token", "proj") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models": {}}`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 失败: %v", err) + } + if resp.Models == nil { + t.Fatal("Models 不应为 nil") + } + if len(resp.Models) != 0 { + t.Errorf("Models 应为空: got %d", len(resp.Models)) + } + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssist 和 FetchAvailableModels 的 408 fallback 测试 +// --------------------------------------------------------------------------- + +func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusRequestTimeout) + _, _ = w.Write([]byte(`timeout`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"cloudaicompanionProject":"p2","currentTier":"free-tier"}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + resp, _, err := client.LoadCodeAssist(context.Background(), "token") + if err != nil { + t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err) + } + if resp.CloudAICompanionProject != "p2" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } +} + +func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`not found`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models":{"m1":{"quotaInfo":{"remainingFraction":1.0}}}}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err) + } + if _, ok := resp.Models["m1"]; !ok { + t.Error("应返回 fallback server 的模型 m1") + } +} + +func TestExtractProjectIDFromOnboardResponse(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + resp map[string]any + want string + }{ + { + name: "nil response", + resp: nil, + want: "", + }, + { + name: "empty response", + resp: map[string]any{}, + want: "", + }, + { + name: "project as string", + resp: map[string]any{ + "cloudaicompanionProject": "my-project-123", + }, + want: "my-project-123", + }, + { + name: "project as string with spaces", + resp: map[string]any{ + "cloudaicompanionProject": " my-project-123 ", + }, + want: "my-project-123", + }, + { + name: "project as map with id", + resp: map[string]any{ + "cloudaicompanionProject": map[string]any{ + "id": "proj-from-map", + }, + }, + want: "proj-from-map", + }, + { + name: "project as map without id", + resp: map[string]any{ + "cloudaicompanionProject": map[string]any{ + "name": "some-name", + }, + }, + want: "", + }, + { + name: "missing cloudaicompanionProject key", + resp: map[string]any{ + "otherField": "value", + }, + want: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := extractProjectIDFromOnboardResponse(tc.resp) + if got != tc.want { + t.Fatalf("extractProjectIDFromOnboardResponse() = %q, want %q", got, tc.want) + } + }) + } +} diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index c1cc998c..0ff24a1f 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -70,7 +70,7 @@ type GeminiGenerationConfig struct { ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"` } -// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持) +// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持) type GeminiImageConfig struct { AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4" ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K" @@ -155,6 +155,7 @@ type GeminiUsageMetadata struct { CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` TotalTokenCount int `json:"totalTokenCount,omitempty"` + ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费) } // GeminiGroundingMetadata Gemini grounding 元数据(Web Search) diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index d1712c98..18310655 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -6,10 +6,14 @@ import ( "encoding/base64" "encoding/hex" "fmt" + "net/http" "net/url" + "os" "strings" "sync" "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) const ( @@ -19,8 +23,10 @@ const ( UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo" // Antigravity OAuth 客户端凭证 - ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" - ClientSecret = "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" + ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + + // AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。 + AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET" // 固定的 redirect_uri(用户需手动复制 code) RedirectURI = "http://localhost:8085/callback" @@ -32,9 +38,6 @@ const ( "https://www.googleapis.com/auth/cclog " + "https://www.googleapis.com/auth/experimentsandconfigs" - // User-Agent(与 Antigravity-Manager 保持一致) - UserAgent = "antigravity/1.15.8 windows/amd64" - // Session 过期时间 SessionTTL = 30 * time.Minute @@ -46,6 +49,36 @@ const ( antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" ) +// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6 +var defaultUserAgentVersion = "1.19.6" + +// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 +// 默认值使用占位符,生产环境请通过环境变量注入真实值。 +var defaultClientSecret = "GOCSPX-your-client-secret" + +func init() { + // 从环境变量读取版本号,未设置则使用默认值 + if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" { + defaultUserAgentVersion = version + } + // 从环境变量读取 client_secret,未设置则使用默认值 + if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" { + defaultClientSecret = secret + } +} + +// GetUserAgent 返回当前配置的 User-Agent +func GetUserAgent() string { + return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion) +} + +func getClientSecret() (string, error) { + if v := strings.TrimSpace(defaultClientSecret); v != "" { + return v, nil + } + return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv) +} + // BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致) var BaseURLs = []string{ antigravityProdBaseURL, // prod (优先) diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go new file mode 100644 index 00000000..2a2a52e9 --- /dev/null +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -0,0 +1,718 @@ +//go:build unit + +package antigravity + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "net/url" + "os" + "strings" + "testing" + "time" +) + +// --------------------------------------------------------------------------- +// getClientSecret +// --------------------------------------------------------------------------- + +func TestGetClientSecret_环境变量设置(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) + t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value") + + // 需要重新触发 init 逻辑:手动从环境变量读取 + defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv) + + secret, err := getClientSecret() + if err != nil { + t.Fatalf("获取 client_secret 失败: %v", err) + } + if secret != "my-secret-value" { + t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret) + } +} + +func TestGetClientSecret_环境变量为空(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) + + _, err := getClientSecret() + if err == nil { + t.Fatal("defaultClientSecret 为空时应返回错误") + } + if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { + t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) + } +} + +func TestGetClientSecret_环境变量未设置(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) + + _, err := getClientSecret() + if err == nil { + t.Fatal("defaultClientSecret 为空时应返回错误") + } +} + +func TestGetClientSecret_环境变量含空格(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = " " + t.Cleanup(func() { defaultClientSecret = old }) + + _, err := getClientSecret() + if err == nil { + t.Fatal("defaultClientSecret 仅含空格时应返回错误") + } +} + +func TestGetClientSecret_环境变量有前后空格(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = " valid-secret " + t.Cleanup(func() { defaultClientSecret = old }) + + secret, err := getClientSecret() + if err != nil { + t.Fatalf("获取 client_secret 失败: %v", err) + } + if secret != "valid-secret" { + t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret") + } +} + +// --------------------------------------------------------------------------- +// ForwardBaseURLs +// --------------------------------------------------------------------------- + +func TestForwardBaseURLs_Daily优先(t *testing.T) { + urls := ForwardBaseURLs() + if len(urls) == 0 { + t.Fatal("ForwardBaseURLs 返回空列表") + } + + // daily URL 应排在第一位 + if urls[0] != antigravityDailyBaseURL { + t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL) + } + + // 应包含所有 URL + if len(urls) != len(BaseURLs) { + t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs)) + } + + // 验证 prod URL 也在列表中 + found := false + for _, u := range urls { + if u == antigravityProdBaseURL { + found = true + break + } + } + if !found { + t.Error("ForwardBaseURLs 中缺少 prod URL") + } +} + +func TestForwardBaseURLs_不修改原切片(t *testing.T) { + originalFirst := BaseURLs[0] + _ = ForwardBaseURLs() + // 确保原始 BaseURLs 未被修改 + if BaseURLs[0] != originalFirst { + t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst) + } +} + +// --------------------------------------------------------------------------- +// URLAvailability +// --------------------------------------------------------------------------- + +func TestNewURLAvailability(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + if ua == nil { + t.Fatal("NewURLAvailability 返回 nil") + } + if ua.ttl != 5*time.Minute { + t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl) + } + if ua.unavailable == nil { + t.Error("unavailable map 不应为 nil") + } +} + +func TestURLAvailability_MarkUnavailable(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + testURL := "https://example.com" + + ua.MarkUnavailable(testURL) + + if ua.IsAvailable(testURL) { + t.Error("标记为不可用后 IsAvailable 应返回 false") + } +} + +func TestURLAvailability_MarkSuccess(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + testURL := "https://example.com" + + // 先标记为不可用 + ua.MarkUnavailable(testURL) + if ua.IsAvailable(testURL) { + t.Error("标记为不可用后应不可用") + } + + // 标记成功后应恢复可用 + ua.MarkSuccess(testURL) + if !ua.IsAvailable(testURL) { + t.Error("MarkSuccess 后应恢复可用") + } + + // 验证 lastSuccess 被设置 + ua.mu.RLock() + if ua.lastSuccess != testURL { + t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL) + } + ua.mu.RUnlock() +} + +func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) { + // 使用极短的 TTL + ua := NewURLAvailability(1 * time.Millisecond) + testURL := "https://example.com" + + ua.MarkUnavailable(testURL) + // 等待 TTL 过期 + time.Sleep(5 * time.Millisecond) + + if !ua.IsAvailable(testURL) { + t.Error("TTL 过期后 URL 应恢复可用") + } +} + +func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + if !ua.IsAvailable("https://never-marked.com") { + t.Error("未标记的 URL 应默认可用") + } +} + +func TestURLAvailability_GetAvailableURLs(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + + // 默认所有 URL 都可用 + urls := ua.GetAvailableURLs() + if len(urls) != len(BaseURLs) { + t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs)) + } +} + +func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + + if len(BaseURLs) < 2 { + t.Skip("BaseURLs 少于 2 个,跳过此测试") + } + + ua.MarkUnavailable(BaseURLs[0]) + urls := ua.GetAvailableURLs() + + // 标记的 URL 不应出现在可用列表中 + for _, u := range urls { + if u == BaseURLs[0] { + t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0]) + } + } +} + +func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com", "https://c.com"} + + urls := ua.GetAvailableURLsWithBase(customURLs) + if len(urls) != 3 { + t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls)) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com", "https://c.com"} + + ua.MarkSuccess("https://c.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + if len(urls) != 3 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls)) + } + // c.com 应排在第一位 + if urls[0] != "https://c.com" { + t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0]) + } + // 其余按原始顺序 + if urls[1] != "https://a.com" { + t.Errorf("第二个应为 a.com: got %s", urls[1]) + } + if urls[2] != "https://b.com" { + t.Errorf("第三个应为 b.com: got %s", urls[2]) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com"} + + ua.MarkSuccess("https://b.com") + ua.MarkUnavailable("https://b.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + // b.com 被标记不可用,不应出现 + if len(urls) != 1 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls)) + } + if urls[0] != "https://a.com" { + t.Errorf("仅 a.com 应可用: got %s", urls[0]) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com"} + + ua.MarkSuccess("https://not-in-list.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + // lastSuccess 不在自定义列表中,不应被添加 + if len(urls) != 2 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls)) + } +} + +// --------------------------------------------------------------------------- +// SessionStore +// --------------------------------------------------------------------------- + +func TestNewSessionStore(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + if store == nil { + t.Fatal("NewSessionStore 返回 nil") + } + if store.sessions == nil { + t.Error("sessions map 不应为 nil") + } +} + +func TestSessionStore_SetAndGet(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "test-state", + CodeVerifier: "test-verifier", + ProxyURL: "http://proxy.example.com", + CreatedAt: time.Now(), + } + + store.Set("session-1", session) + + got, ok := store.Get("session-1") + if !ok { + t.Fatal("Get 应返回 true") + } + if got.State != "test-state" { + t.Errorf("State 不匹配: got %s", got.State) + } + if got.CodeVerifier != "test-verifier" { + t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier) + } + if got.ProxyURL != "http://proxy.example.com" { + t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL) + } +} + +func TestSessionStore_Get_不存在(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + _, ok := store.Get("nonexistent") + if ok { + t.Error("不存在的 session 应返回 false") + } +} + +func TestSessionStore_Get_过期(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "expired-state", + CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期 + } + + store.Set("expired-session", session) + + _, ok := store.Get("expired-session") + if ok { + t.Error("过期的 session 应返回 false") + } +} + +func TestSessionStore_Delete(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "to-delete", + CreatedAt: time.Now(), + } + + store.Set("del-session", session) + store.Delete("del-session") + + _, ok := store.Get("del-session") + if ok { + t.Error("删除后 Get 应返回 false") + } +} + +func TestSessionStore_Delete_不存在(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + // 删除不存在的 session 不应 panic + store.Delete("nonexistent") +} + +func TestSessionStore_Stop(t *testing.T) { + store := NewSessionStore() + store.Stop() + + // 多次 Stop 不应 panic + store.Stop() +} + +func TestSessionStore_多个Session(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + for i := 0; i < 10; i++ { + session := &OAuthSession{ + State: "state-" + string(rune('0'+i)), + CreatedAt: time.Now(), + } + store.Set("session-"+string(rune('0'+i)), session) + } + + // 验证都能取到 + for i := 0; i < 10; i++ { + _, ok := store.Get("session-" + string(rune('0'+i))) + if !ok { + t.Errorf("session-%d 应存在", i) + } + } +} + +// --------------------------------------------------------------------------- +// GenerateRandomBytes +// --------------------------------------------------------------------------- + +func TestGenerateRandomBytes_长度正确(t *testing.T) { + sizes := []int{0, 1, 16, 32, 64, 128} + for _, size := range sizes { + b, err := GenerateRandomBytes(size) + if err != nil { + t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err) + } + if len(b) != size { + t.Errorf("长度不匹配: got %d, want %d", len(b), size) + } + } +} + +func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) { + b1, err := GenerateRandomBytes(32) + if err != nil { + t.Fatalf("第一次调用失败: %v", err) + } + b2, err := GenerateRandomBytes(32) + if err != nil { + t.Fatalf("第二次调用失败: %v", err) + } + // 两次生成的随机字节应该不同(概率上几乎不可能相同) + if string(b1) == string(b2) { + t.Error("两次生成的随机字节相同,概率极低,可能有问题") + } +} + +// --------------------------------------------------------------------------- +// GenerateState +// --------------------------------------------------------------------------- + +func TestGenerateState_返回值格式(t *testing.T) { + state, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState 失败: %v", err) + } + if state == "" { + t.Error("GenerateState 返回空字符串") + } + // base64url 编码不应包含 +, /, = + if strings.ContainsAny(state, "+/=") { + t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state) + } + // 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充) + if len(state) != 43 { + t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state)) + } +} + +func TestGenerateState_唯一性(t *testing.T) { + s1, _ := GenerateState() + s2, _ := GenerateState() + if s1 == s2 { + t.Error("两次 GenerateState 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateSessionID +// --------------------------------------------------------------------------- + +func TestGenerateSessionID_返回值格式(t *testing.T) { + id, err := GenerateSessionID() + if err != nil { + t.Fatalf("GenerateSessionID 失败: %v", err) + } + if id == "" { + t.Error("GenerateSessionID 返回空字符串") + } + // 16 字节的 hex 编码长度应为 32 + if len(id) != 32 { + t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id)) + } + // 验证是合法的 hex 字符串 + if _, err := hex.DecodeString(id); err != nil { + t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err) + } +} + +func TestGenerateSessionID_唯一性(t *testing.T) { + id1, _ := GenerateSessionID() + id2, _ := GenerateSessionID() + if id1 == id2 { + t.Error("两次 GenerateSessionID 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeVerifier +// --------------------------------------------------------------------------- + +func TestGenerateCodeVerifier_返回值格式(t *testing.T) { + verifier, err := GenerateCodeVerifier() + if err != nil { + t.Fatalf("GenerateCodeVerifier 失败: %v", err) + } + if verifier == "" { + t.Error("GenerateCodeVerifier 返回空字符串") + } + // base64url 编码不应包含 +, /, = + if strings.ContainsAny(verifier, "+/=") { + t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier) + } + // 32 字节的 base64url 编码长度应为 43 + if len(verifier) != 43 { + t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier)) + } +} + +func TestGenerateCodeVerifier_唯一性(t *testing.T) { + v1, _ := GenerateCodeVerifier() + v2, _ := GenerateCodeVerifier() + if v1 == v2 { + t.Error("两次 GenerateCodeVerifier 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeChallenge +// --------------------------------------------------------------------------- + +func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) { + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + + challenge := GenerateCodeChallenge(verifier) + + // 手动计算预期值 + hash := sha256.Sum256([]byte(verifier)) + expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=") + + if challenge != expected { + t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected) + } +} + +func TestGenerateCodeChallenge_不含填充字符(t *testing.T) { + challenge := GenerateCodeChallenge("test-verifier") + if strings.Contains(challenge, "=") { + t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge) + } +} + +func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) { + challenge := GenerateCodeChallenge("another-verifier") + if strings.ContainsAny(challenge, "+/") { + t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge) + } +} + +func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) { + c1 := GenerateCodeChallenge("same-verifier") + c2 := GenerateCodeChallenge("same-verifier") + if c1 != c2 { + t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2) + } +} + +func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) { + c1 := GenerateCodeChallenge("verifier-1") + c2 := GenerateCodeChallenge("verifier-2") + if c1 == c2 { + t.Error("不同输入应产生不同输出") + } +} + +// --------------------------------------------------------------------------- +// BuildAuthorizationURL +// --------------------------------------------------------------------------- + +func TestBuildAuthorizationURL_参数验证(t *testing.T) { + state := "test-state-123" + codeChallenge := "test-challenge-abc" + + authURL := BuildAuthorizationURL(state, codeChallenge) + + // 验证以 AuthorizeURL 开头 + if !strings.HasPrefix(authURL, AuthorizeURL+"?") { + t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL) + } + + // 解析 URL 并验证参数 + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + params := parsed.Query() + + expectedParams := map[string]string{ + "client_id": ClientID, + "redirect_uri": RedirectURI, + "response_type": "code", + "scope": Scopes, + "state": state, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "access_type": "offline", + "prompt": "consent", + "include_granted_scopes": "true", + } + + for key, want := range expectedParams { + got := params.Get(key) + if got != want { + t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want) + } + } +} + +func TestBuildAuthorizationURL_参数数量(t *testing.T) { + authURL := BuildAuthorizationURL("s", "c") + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + params := parsed.Query() + // 应包含 10 个参数 + expectedCount := 10 + if len(params) != expectedCount { + t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount) + } +} + +func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) { + state := "state+with/special=chars" + codeChallenge := "challenge+value" + + authURL := BuildAuthorizationURL(state, codeChallenge) + + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + // 解析后应正确还原特殊字符 + if got := parsed.Query().Get("state"); got != state { + t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state) + } +} + +// --------------------------------------------------------------------------- +// 常量值验证 +// --------------------------------------------------------------------------- + +func TestConstants_值正确(t *testing.T) { + if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" { + t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL) + } + if TokenURL != "https://oauth2.googleapis.com/token" { + t.Errorf("TokenURL 不匹配: got %s", TokenURL) + } + if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" { + t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL) + } + if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" { + t.Errorf("ClientID 不匹配: got %s", ClientID) + } + secret, err := getClientSecret() + if err != nil { + t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err) + } + if secret != "GOCSPX-your-client-secret" { + t.Errorf("默认 client_secret 不匹配: got %s", secret) + } + if RedirectURI != "http://localhost:8085/callback" { + t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) + } + if GetUserAgent() != "antigravity/1.19.6 windows/amd64" { + t.Errorf("UserAgent 不匹配: got %s", GetUserAgent()) + } + if SessionTTL != 30*time.Minute { + t.Errorf("SessionTTL 不匹配: got %v", SessionTTL) + } + if URLAvailabilityTTL != 5*time.Minute { + t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL) + } +} + +func TestScopes_包含必要范围(t *testing.T) { + expectedScopes := []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/cclog", + "https://www.googleapis.com/auth/experimentsandconfigs", + } + + for _, scope := range expectedScopes { + if !strings.Contains(Scopes, scope) { + t.Errorf("Scopes 缺少 %s", scope) + } + } +} diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 65f45cfc..55cdd786 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -64,6 +64,10 @@ const MaxTokensBudgetPadding = 1000 // Gemini 2.5 Flash thinking budget 上限 const Gemini25FlashThinkingBudgetLimit = 24576 +// 对于 Antigravity 的 Claude(budget-only)模型,该语义最终等价为 thinkingBudget=24576。 +// 这里复用相同数值以保持行为一致。 +const ClaudeAdaptiveHighThinkingBudgetTokens = Gemini25FlashThinkingBudgetLimit + // ensureMaxTokensGreaterThanBudget 确保 max_tokens > budget_tokens // Claude API 要求启用 thinking 时,max_tokens 必须大于 thinking.budget_tokens // 返回调整后的 maxTokens 和是否进行了调整 @@ -96,7 +100,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map } // 检测是否启用 thinking - isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + isThinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") // 只有 Gemini 模型支持 dummy thought workaround // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures @@ -198,11 +202,11 @@ type modelInfo struct { // modelInfoMap 模型前缀 → 模型信息映射 // 只有在此映射表中的模型才会注入身份提示词 -// 注意:当前 claude-opus-4-6 会被映射到 claude-opus-4-5-thinking, -// 但保留此条目以便后续 Antigravity 上游支持 4.6 时快速切换 +// 注意:模型映射逻辑在网关层完成;这里仅用于按模型前缀判断是否注入身份提示词。 var modelInfoMap = map[string]modelInfo{ "claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"}, "claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"}, + "claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"}, "claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"}, "claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"}, } @@ -271,6 +275,21 @@ func filterOpenCodePrompt(text string) string { return "" } +// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表 +var systemBlockFilterPrefixes = []string{ + "x-anthropic-billing-header", +} + +// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串 +func filterSystemBlockByPrefix(text string) string { + for _, prefix := range systemBlockFilterPrefixes { + if strings.HasPrefix(text, prefix) { + return "" + } + } + return text +} + // buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致) func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent { var parts []GeminiPart @@ -287,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans if strings.Contains(sysStr, "You are Antigravity") { userHasAntigravityIdentity = true } - // 过滤 OpenCode 默认提示词 - filtered := filterOpenCodePrompt(sysStr) + // 过滤 OpenCode 默认提示词和黑名单前缀 + filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr)) if filtered != "" { userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) } @@ -302,8 +321,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans if strings.Contains(block.Text, "You are Antigravity") { userHasAntigravityIdentity = true } - // 过滤 OpenCode 默认提示词 - filtered := filterOpenCodePrompt(block.Text) + // 过滤 OpenCode 默认提示词和黑名单前缀 + filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text)) if filtered != "" { userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) } @@ -578,6 +597,10 @@ func maxOutputTokensLimit(model string) int { return maxOutputTokensUpperBound } +func isAntigravityOpus46Model(model string) bool { + return strings.HasPrefix(strings.ToLower(model), "claude-opus-4-6") +} + func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { maxLimit := maxOutputTokensLimit(req.Model) config := &GeminiGenerationConfig{ @@ -591,25 +614,36 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { } // Thinking 配置 - if req.Thinking != nil && req.Thinking.Type == "enabled" { + if req.Thinking != nil && (req.Thinking.Type == "enabled" || req.Thinking.Type == "adaptive") { config.ThinkingConfig = &GeminiThinkingConfig{ IncludeThoughts: true, } + + // - thinking.type=enabled:budget_tokens>0 用显式预算 + // - thinking.type=adaptive:仅在 Antigravity 的 Opus 4.6 上覆写为 (24576) + budget := -1 if req.Thinking.BudgetTokens > 0 { - budget := req.Thinking.BudgetTokens + budget = req.Thinking.BudgetTokens + } + if req.Thinking.Type == "adaptive" && isAntigravityOpus46Model(req.Model) { + budget = ClaudeAdaptiveHighThinkingBudgetTokens + } + + // 正预算需要做上限与 max_tokens 约束;动态预算(-1)直接透传给上游。 + if budget > 0 { // gemini-2.5-flash 上限 if strings.Contains(req.Model, "gemini-2.5-flash") && budget > Gemini25FlashThinkingBudgetLimit { budget = Gemini25FlashThinkingBudgetLimit } - config.ThinkingConfig.ThinkingBudget = budget - // 自动修正:max_tokens 必须大于 budget_tokens + // 自动修正:max_tokens 必须大于 budget_tokens(Claude 上游要求) if adjusted, ok := ensureMaxTokensGreaterThanBudget(config.MaxOutputTokens, budget); ok { log.Printf("[Antigravity] Auto-adjusted max_tokens from %d to %d (must be > budget_tokens=%d)", config.MaxOutputTokens, adjusted, budget) config.MaxOutputTokens = adjusted } } + config.ThinkingConfig.ThinkingBudget = budget } if config.MaxOutputTokens > maxLimit { diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index f938b47f..f267e0e1 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -259,3 +259,93 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { }) } } + +func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) { + tests := []struct { + name string + model string + thinking *ThinkingConfig + wantBudget int + wantPresent bool + }{ + { + name: "enabled without budget defaults to dynamic (-1)", + model: "claude-opus-4-6-thinking", + thinking: &ThinkingConfig{Type: "enabled"}, + wantBudget: -1, + wantPresent: true, + }, + { + name: "enabled with budget uses the provided value", + model: "claude-opus-4-6-thinking", + thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1024}, + wantBudget: 1024, + wantPresent: true, + }, + { + name: "enabled with -1 budget uses dynamic (-1)", + model: "claude-opus-4-6-thinking", + thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: -1}, + wantBudget: -1, + wantPresent: true, + }, + { + name: "adaptive on opus4.6 maps to high budget (24576)", + model: "claude-opus-4-6-thinking", + thinking: &ThinkingConfig{Type: "adaptive", BudgetTokens: 20000}, + wantBudget: ClaudeAdaptiveHighThinkingBudgetTokens, + wantPresent: true, + }, + { + name: "adaptive on non-opus model keeps default dynamic (-1)", + model: "claude-sonnet-4-5-thinking", + thinking: &ThinkingConfig{Type: "adaptive"}, + wantBudget: -1, + wantPresent: true, + }, + { + name: "disabled does not emit thinkingConfig", + model: "claude-opus-4-6-thinking", + thinking: &ThinkingConfig{Type: "disabled", BudgetTokens: 1024}, + wantBudget: 0, + wantPresent: false, + }, + { + name: "nil thinking does not emit thinkingConfig", + model: "claude-opus-4-6-thinking", + thinking: nil, + wantBudget: 0, + wantPresent: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &ClaudeRequest{ + Model: tt.model, + Thinking: tt.thinking, + } + cfg := buildGenerationConfig(req) + if cfg == nil { + t.Fatalf("expected non-nil generationConfig") + } + + if tt.wantPresent { + if cfg.ThinkingConfig == nil { + t.Fatalf("expected thinkingConfig to be present") + } + if !cfg.ThinkingConfig.IncludeThoughts { + t.Fatalf("expected includeThoughts=true") + } + if cfg.ThinkingConfig.ThinkingBudget != tt.wantBudget { + t.Fatalf("expected thinkingBudget=%d, got %d", tt.wantBudget, cfg.ThinkingConfig.ThinkingBudget) + } + return + } + + if cfg.ThinkingConfig != nil { + t.Fatalf("expected thinkingConfig to be nil, got %+v", cfg.ThinkingConfig) + } + }) + } +} diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index eb16f09d..f12effb6 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -1,10 +1,13 @@ package antigravity import ( + "crypto/rand" "encoding/json" "fmt" "log" "strings" + "sync/atomic" + "time" ) // TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式) @@ -279,7 +282,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon if geminiResp.UsageMetadata != nil { cached := geminiResp.UsageMetadata.CachedContentTokenCount usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached - usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount usage.CacheReadInputTokens = cached } @@ -341,12 +344,30 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string { return builder.String() } -// generateRandomID 生成随机 ID +// fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。 +var fallbackCounter uint64 + +// generateRandomID 生成密码学安全的随机 ID func generateRandomID() string { const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - result := make([]byte, 12) - for i := range result { - result[i] = chars[i%len(chars)] + id := make([]byte, 12) + randBytes := make([]byte, 12) + if _, err := rand.Read(randBytes); err != nil { + // 避免在请求路径里 panic:极端情况下熵源不可用时降级为伪随机。 + // 这里主要用于生成响应/工具调用的临时 ID,安全要求不高但需尽量避免碰撞。 + cnt := atomic.AddUint64(&fallbackCounter, 1) + seed := uint64(time.Now().UnixNano()) ^ cnt + seed ^= uint64(len(err.Error())) << 32 + for i := range id { + seed ^= seed << 13 + seed ^= seed >> 7 + seed ^= seed << 17 + id[i] = chars[int(seed)%len(chars)] + } + return string(id) } - return string(result) + for i, b := range randBytes { + id[i] = chars[int(b)%len(chars)] + } + return string(id) } diff --git a/backend/internal/pkg/antigravity/response_transformer_test.go b/backend/internal/pkg/antigravity/response_transformer_test.go new file mode 100644 index 00000000..da402b17 --- /dev/null +++ b/backend/internal/pkg/antigravity/response_transformer_test.go @@ -0,0 +1,109 @@ +//go:build unit + +package antigravity + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 7: 验证 generateRandomID 和降级碰撞防护 --- + +func TestGenerateRandomID_Uniqueness(t *testing.T) { + seen := make(map[string]struct{}, 100) + for i := 0; i < 100; i++ { + id := generateRandomID() + require.Len(t, id, 12, "ID 长度应为 12") + _, dup := seen[id] + require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id) + seen[id] = struct{}{} + } +} + +func TestFallbackCounter_Increments(t *testing.T) { + // 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed + before := atomic.LoadUint64(&fallbackCounter) + cnt1 := atomic.AddUint64(&fallbackCounter, 1) + cnt2 := atomic.AddUint64(&fallbackCounter, 1) + require.Equal(t, before+1, cnt1, "第一次递增应为 before+1") + require.Equal(t, before+2, cnt2, "第二次递增应为 before+2") + require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同") +} + +func TestFallbackCounter_ConcurrentIncrements(t *testing.T) { + // 验证并发递增的原子性 — 每次递增都应产生唯一值 + const goroutines = 50 + results := make([]uint64, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx] = atomic.AddUint64(&fallbackCounter, 1) + }(i) + } + wg.Wait() + + // 所有结果应唯一 + seen := make(map[uint64]bool, goroutines) + for _, v := range results { + assert.False(t, seen[v], "并发递增产生了重复值: %d", v) + seen[v] = true + } +} + +func TestGenerateRandomID_Charset(t *testing.T) { + const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + validSet := make(map[byte]struct{}, len(validChars)) + for i := 0; i < len(validChars); i++ { + validSet[validChars[i]] = struct{}{} + } + + for i := 0; i < 50; i++ { + id := generateRandomID() + for j := 0; j < len(id); j++ { + _, ok := validSet[id[j]] + require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id) + } + } +} + +func TestGenerateRandomID_Length(t *testing.T) { + for i := 0; i < 100; i++ { + id := generateRandomID() + assert.Len(t, id, 12, "每次生成的 ID 长度应为 12") + } +} + +func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) { + // 验证并发调用不会产生重复 ID + const goroutines = 100 + results := make([]string, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx] = generateRandomID() + }(i) + } + wg.Wait() + + seen := make(map[string]bool, goroutines) + for _, id := range results { + assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id) + seen[id] = true + } +} + +func BenchmarkGenerateRandomID(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = generateRandomID() + } +} diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index b384658a..677435ad 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -85,7 +85,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { if geminiResp.UsageMetadata != nil { cached := geminiResp.UsageMetadata.CachedContentTokenCount p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached - p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount p.cacheReadTokens = cached } @@ -146,7 +146,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte if v1Resp.Response.UsageMetadata != nil { cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached - usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount usage.CacheReadInputTokens = cached } diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index eecee11e..22405382 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -10,8 +10,14 @@ const ( BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" BetaTokenCounting = "token-counting-2024-11-01" + BetaContext1M = "context-1m-2025-08-07" + BetaFastMode = "fast-mode-2026-02-01" ) +// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。 +// 这些 token 是客户端特有的,不应透传给上游 API。 +var DroppedBetas = []string{BetaContext1M, BetaFastMode} + // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming @@ -77,6 +83,12 @@ var DefaultModels = []Model{ DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-06T00:00:00Z", }, + { + ID: "claude-sonnet-4-6", + Type: "model", + DisplayName: "Claude Sonnet 4.6", + CreatedAt: "2026-02-18T00:00:00Z", + }, { ID: "claude-sonnet-4-5-20250929", Type: "model", diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 9bf563e7..25782c55 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -8,9 +8,21 @@ const ( // ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置 ForcePlatform Key = "ctx_force_platform" + // RequestID 为服务端生成/透传的请求 ID。 + RequestID Key = "ctx_request_id" + // ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。 ClientRequestID Key = "ctx_client_request_id" + // Model 请求模型标识(用于统一请求链路日志字段)。 + Model Key = "ctx_model" + + // Platform 当前请求最终命中的平台(用于统一请求链路日志字段)。 + Platform Key = "ctx_platform" + + // AccountID 当前请求最终命中的账号 ID(用于统一请求链路日志字段)。 + AccountID Key = "ctx_account_id" + // RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。 RetryCount Key = "ctx_retry_count" @@ -28,4 +40,19 @@ const ( // IsMaxTokensOneHaikuRequest 标识当前请求是否为 max_tokens=1 + haiku 模型的探测请求 // 用于 ClaudeCodeOnly 验证绕过(绕过 system prompt 检查,但仍需验证 User-Agent) IsMaxTokensOneHaikuRequest Key = "ctx_is_max_tokens_one_haiku" + + // SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。 + // 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。 + SingleAccountRetry Key = "ctx_single_account_retry" + + // PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。 + // Service 层可复用该值,避免同请求链路重复读取 Redis。 + PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id" + + // PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。 + // Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。 + PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id" + + // ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22") + ClaudeCodeVersion Key = "ctx_claude_code_version" ) diff --git a/backend/internal/pkg/errors/errors_test.go b/backend/internal/pkg/errors/errors_test.go index 1a1c842e..25e62907 100644 --- a/backend/internal/pkg/errors/errors_test.go +++ b/backend/internal/pkg/errors/errors_test.go @@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) { }) } } + +func TestToHTTP_MetadataDeepCopy(t *testing.T) { + md := map[string]string{"k": "v"} + appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md) + + code, body := ToHTTP(appErr) + require.Equal(t, http.StatusBadRequest, code) + require.Equal(t, "v", body.Metadata["k"]) + + md["k"] = "changed" + require.Equal(t, "v", body.Metadata["k"]) + + appErr.Metadata["k"] = "changed-again" + require.Equal(t, "v", body.Metadata["k"]) +} diff --git a/backend/internal/pkg/errors/http.go b/backend/internal/pkg/errors/http.go index 7b5560e3..420c69a3 100644 --- a/backend/internal/pkg/errors/http.go +++ b/backend/internal/pkg/errors/http.go @@ -16,6 +16,16 @@ func ToHTTP(err error) (statusCode int, body Status) { return http.StatusOK, Status{Code: int32(http.StatusOK)} } - cloned := Clone(appErr) - return int(cloned.Code), cloned.Status + body = Status{ + Code: appErr.Code, + Reason: appErr.Reason, + Message: appErr.Message, + } + if appErr.Metadata != nil { + body.Metadata = make(map[string]string, len(appErr.Metadata)) + for k, v := range appErr.Metadata { + body.Metadata[k] = v + } + } + return int(appErr.Code), body } diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index 424e8ddb..c300b17d 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -21,6 +21,7 @@ func DefaultModels() []Model { {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, } } diff --git a/backend/internal/pkg/geminicli/constants.go b/backend/internal/pkg/geminicli/constants.go index d4d52116..f5ee5735 100644 --- a/backend/internal/pkg/geminicli/constants.go +++ b/backend/internal/pkg/geminicli/constants.go @@ -39,7 +39,10 @@ const ( // They enable the "login without creating your own OAuth client" experience, but Google may // restrict which scopes are allowed for this client. GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - GeminiCLIOAuthClientSecret = "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" + GeminiCLIOAuthClientSecret = "GOCSPX-your-client-secret" + + // GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret. + GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET" SessionTTL = 30 * time.Minute diff --git a/backend/internal/pkg/geminicli/models.go b/backend/internal/pkg/geminicli/models.go index 08e69886..1fc4d983 100644 --- a/backend/internal/pkg/geminicli/models.go +++ b/backend/internal/pkg/geminicli/models.go @@ -16,6 +16,7 @@ var DefaultModels = []Model{ {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""}, {ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""}, {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, + {ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""}, } // DefaultTestModel is the default model to preselect in test flows. diff --git a/backend/internal/pkg/geminicli/oauth.go b/backend/internal/pkg/geminicli/oauth.go index c71e8aad..b10b5750 100644 --- a/backend/internal/pkg/geminicli/oauth.go +++ b/backend/internal/pkg/geminicli/oauth.go @@ -6,10 +6,14 @@ import ( "encoding/base64" "encoding/hex" "fmt" + "net/http" "net/url" + "os" "strings" "sync" "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) type OAuthConfig struct { @@ -164,15 +168,24 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error } // Fall back to built-in Gemini CLI OAuth client when not configured. + // SECURITY: This repo does not embed the built-in client secret; it must be provided via env. if effective.ClientID == "" && effective.ClientSecret == "" { + secret := strings.TrimSpace(GeminiCLIOAuthClientSecret) + if secret == "" { + if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok { + secret = strings.TrimSpace(v) + } + } + if secret == "" { + return OAuthConfig{}, infraerrors.Newf(http.StatusBadRequest, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING", "built-in Gemini CLI OAuth client_secret is not configured; set %s or provide a custom OAuth client", GeminiCLIOAuthClientSecretEnv) + } effective.ClientID = GeminiCLIOAuthClientID - effective.ClientSecret = GeminiCLIOAuthClientSecret + effective.ClientSecret = secret } else if effective.ClientID == "" || effective.ClientSecret == "" { - return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)") + return OAuthConfig{}, infraerrors.New(http.StatusBadRequest, "GEMINI_OAUTH_CLIENT_NOT_CONFIGURED", "OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)") } - isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID && - effective.ClientSecret == GeminiCLIOAuthClientSecret + isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID if effective.Scopes == "" { // Use different default scopes based on OAuth type diff --git a/backend/internal/pkg/geminicli/oauth_test.go b/backend/internal/pkg/geminicli/oauth_test.go index 0770730a..2a430f9e 100644 --- a/backend/internal/pkg/geminicli/oauth_test.go +++ b/backend/internal/pkg/geminicli/oauth_test.go @@ -1,11 +1,441 @@ package geminicli import ( + "encoding/hex" "strings" + "sync" "testing" + "time" ) +// --------------------------------------------------------------------------- +// SessionStore 测试 +// --------------------------------------------------------------------------- + +func TestSessionStore_SetAndGet(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "test-state", + OAuthType: "code_assist", + CreatedAt: time.Now(), + } + store.Set("sid-1", session) + + got, ok := store.Get("sid-1") + if !ok { + t.Fatal("期望 Get 返回 ok=true,实际返回 false") + } + if got.State != "test-state" { + t.Errorf("期望 State=%q,实际=%q", "test-state", got.State) + } +} + +func TestSessionStore_GetNotFound(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + _, ok := store.Get("不存在的ID") + if ok { + t.Error("期望不存在的 sessionID 返回 ok=false") + } +} + +func TestSessionStore_GetExpired(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + // 创建一个已过期的 session(CreatedAt 设置为 SessionTTL+1 分钟之前) + session := &OAuthSession{ + State: "expired-state", + OAuthType: "code_assist", + CreatedAt: time.Now().Add(-(SessionTTL + 1*time.Minute)), + } + store.Set("expired-sid", session) + + _, ok := store.Get("expired-sid") + if ok { + t.Error("期望过期的 session 返回 ok=false") + } +} + +func TestSessionStore_Delete(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "to-delete", + OAuthType: "code_assist", + CreatedAt: time.Now(), + } + store.Set("del-sid", session) + + // 先确认存在 + if _, ok := store.Get("del-sid"); !ok { + t.Fatal("删除前 session 应该存在") + } + + store.Delete("del-sid") + + if _, ok := store.Get("del-sid"); ok { + t.Error("删除后 session 不应该存在") + } +} + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + // 多次调用 Stop 不应 panic + store.Stop() + store.Stop() + store.Stop() +} + +func TestSessionStore_ConcurrentAccess(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + const goroutines = 50 + var wg sync.WaitGroup + wg.Add(goroutines * 3) + + // 并发写入 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Set(sid, &OAuthSession{ + State: sid, + OAuthType: "code_assist", + CreatedAt: time.Now(), + }) + }(i) + } + + // 并发读取 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Get(sid) // 可能找到也可能没找到,关键是不 panic + }(i) + } + + // 并发删除 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Delete(sid) + }(i) + } + + wg.Wait() +} + +// --------------------------------------------------------------------------- +// GenerateRandomBytes 测试 +// --------------------------------------------------------------------------- + +func TestGenerateRandomBytes(t *testing.T) { + tests := []int{0, 1, 16, 32, 64} + for _, n := range tests { + b, err := GenerateRandomBytes(n) + if err != nil { + t.Errorf("GenerateRandomBytes(%d) 出错: %v", n, err) + continue + } + if len(b) != n { + t.Errorf("GenerateRandomBytes(%d) 返回长度=%d,期望=%d", n, len(b), n) + } + } +} + +func TestGenerateRandomBytes_Uniqueness(t *testing.T) { + // 两次调用应该返回不同的结果(极小概率相同,32字节足够) + a, _ := GenerateRandomBytes(32) + b, _ := GenerateRandomBytes(32) + if string(a) == string(b) { + t.Error("两次 GenerateRandomBytes(32) 返回了相同结果,随机性可能有问题") + } +} + +// --------------------------------------------------------------------------- +// GenerateState 测试 +// --------------------------------------------------------------------------- + +func TestGenerateState(t *testing.T) { + state, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState() 出错: %v", err) + } + if state == "" { + t.Error("GenerateState() 返回空字符串") + } + // base64url 编码不应包含 padding '=' + if strings.Contains(state, "=") { + t.Errorf("GenerateState() 结果包含 '=' padding: %s", state) + } + // base64url 不应包含 '+' 或 '/' + if strings.ContainsAny(state, "+/") { + t.Errorf("GenerateState() 结果包含非 base64url 字符: %s", state) + } +} + +// --------------------------------------------------------------------------- +// GenerateSessionID 测试 +// --------------------------------------------------------------------------- + +func TestGenerateSessionID(t *testing.T) { + sid, err := GenerateSessionID() + if err != nil { + t.Fatalf("GenerateSessionID() 出错: %v", err) + } + // 16 字节 -> 32 个 hex 字符 + if len(sid) != 32 { + t.Errorf("GenerateSessionID() 长度=%d,期望=32", len(sid)) + } + // 必须是合法的 hex 字符串 + if _, err := hex.DecodeString(sid); err != nil { + t.Errorf("GenerateSessionID() 不是合法的 hex 字符串: %s, err=%v", sid, err) + } +} + +func TestGenerateSessionID_Uniqueness(t *testing.T) { + a, _ := GenerateSessionID() + b, _ := GenerateSessionID() + if a == b { + t.Error("两次 GenerateSessionID() 返回了相同结果") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeVerifier 测试 +// --------------------------------------------------------------------------- + +func TestGenerateCodeVerifier(t *testing.T) { + verifier, err := GenerateCodeVerifier() + if err != nil { + t.Fatalf("GenerateCodeVerifier() 出错: %v", err) + } + if verifier == "" { + t.Error("GenerateCodeVerifier() 返回空字符串") + } + // RFC 7636 要求 code_verifier 至少 43 个字符 + if len(verifier) < 43 { + t.Errorf("GenerateCodeVerifier() 长度=%d,RFC 7636 要求至少 43 字符", len(verifier)) + } + // base64url 编码不应包含 padding 和非 URL 安全字符 + if strings.Contains(verifier, "=") { + t.Errorf("GenerateCodeVerifier() 包含 '=' padding: %s", verifier) + } + if strings.ContainsAny(verifier, "+/") { + t.Errorf("GenerateCodeVerifier() 包含非 base64url 字符: %s", verifier) + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeChallenge 测试 +// --------------------------------------------------------------------------- + +func TestGenerateCodeChallenge(t *testing.T) { + // 使用已知输入验证输出 + // RFC 7636 附录 B 示例: verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + // 预期 challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + expected := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + + challenge := GenerateCodeChallenge(verifier) + if challenge != expected { + t.Errorf("GenerateCodeChallenge(%q) = %q,期望 %q", verifier, challenge, expected) + } +} + +func TestGenerateCodeChallenge_NoPadding(t *testing.T) { + challenge := GenerateCodeChallenge("test-verifier-string") + if strings.Contains(challenge, "=") { + t.Errorf("GenerateCodeChallenge() 结果包含 '=' padding: %s", challenge) + } +} + +// --------------------------------------------------------------------------- +// base64URLEncode 测试 +// --------------------------------------------------------------------------- + +func TestBase64URLEncode(t *testing.T) { + tests := []struct { + name string + input []byte + }{ + {"空字节", []byte{}}, + {"单字节", []byte{0xff}}, + {"多字节", []byte{0x01, 0x02, 0x03, 0x04, 0x05}}, + {"全零", []byte{0x00, 0x00, 0x00}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := base64URLEncode(tt.input) + // 不应包含 '=' padding + if strings.Contains(result, "=") { + t.Errorf("base64URLEncode(%v) 包含 '=' padding: %s", tt.input, result) + } + // 不应包含标准 base64 的 '+' 或 '/' + if strings.ContainsAny(result, "+/") { + t.Errorf("base64URLEncode(%v) 包含非 URL 安全字符: %s", tt.input, result) + } + }) + } +} + +// --------------------------------------------------------------------------- +// hasRestrictedScope 测试 +// --------------------------------------------------------------------------- + +func TestHasRestrictedScope(t *testing.T) { + tests := []struct { + scope string + expected bool + }{ + // 受限 scope + {"https://www.googleapis.com/auth/generative-language", true}, + {"https://www.googleapis.com/auth/generative-language.retriever", true}, + {"https://www.googleapis.com/auth/generative-language.tuning", true}, + {"https://www.googleapis.com/auth/drive", true}, + {"https://www.googleapis.com/auth/drive.readonly", true}, + {"https://www.googleapis.com/auth/drive.file", true}, + // 非受限 scope + {"https://www.googleapis.com/auth/cloud-platform", false}, + {"https://www.googleapis.com/auth/userinfo.email", false}, + {"https://www.googleapis.com/auth/userinfo.profile", false}, + // 边界情况 + {"", false}, + {"random-scope", false}, + } + for _, tt := range tests { + t.Run(tt.scope, func(t *testing.T) { + got := hasRestrictedScope(tt.scope) + if got != tt.expected { + t.Errorf("hasRestrictedScope(%q) = %v,期望 %v", tt.scope, got, tt.expected) + } + }) + } +} + +// --------------------------------------------------------------------------- +// BuildAuthorizationURL 测试 +// --------------------------------------------------------------------------- + +func TestBuildAuthorizationURL(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + authURL, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "", + "code_assist", + ) + if err != nil { + t.Fatalf("BuildAuthorizationURL() 出错: %v", err) + } + + // 检查返回的 URL 包含期望的参数 + checks := []string{ + "response_type=code", + "client_id=" + GeminiCLIOAuthClientID, + "redirect_uri=", + "state=test-state", + "code_challenge=test-challenge", + "code_challenge_method=S256", + "access_type=offline", + "prompt=consent", + "include_granted_scopes=true", + } + for _, check := range checks { + if !strings.Contains(authURL, check) { + t.Errorf("BuildAuthorizationURL() URL 缺少参数 %q\nURL: %s", check, authURL) + } + } + + // 不应包含 project_id(因为传的是空字符串) + if strings.Contains(authURL, "project_id=") { + t.Errorf("BuildAuthorizationURL() 空 projectID 时不应包含 project_id 参数") + } + + // URL 应该以正确的授权端点开头 + if !strings.HasPrefix(authURL, AuthorizeURL+"?") { + t.Errorf("BuildAuthorizationURL() URL 应以 %s? 开头,实际: %s", AuthorizeURL, authURL) + } +} + +func TestBuildAuthorizationURL_EmptyRedirectURI(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + _, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "", // 空 redirectURI + "", + "code_assist", + ) + if err == nil { + t.Error("BuildAuthorizationURL() 空 redirectURI 应该报错") + } + if !strings.Contains(err.Error(), "redirect_uri") { + t.Errorf("错误消息应包含 'redirect_uri',实际: %v", err) + } +} + +func TestBuildAuthorizationURL_WithProjectID(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + authURL, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "my-project-123", + "code_assist", + ) + if err != nil { + t.Fatalf("BuildAuthorizationURL() 出错: %v", err) + } + if !strings.Contains(authURL, "project_id=my-project-123") { + t.Errorf("BuildAuthorizationURL() 带 projectID 时应包含 project_id 参数\nURL: %s", authURL) + } +} + +func TestBuildAuthorizationURL_UsesBuiltinSecretFallback(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "") + + authURL, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "", + "code_assist", + ) + if err != nil { + t.Fatalf("BuildAuthorizationURL() 不应报错: %v", err) + } + if !strings.Contains(authURL, "client_id="+GeminiCLIOAuthClientID) { + t.Errorf("应使用内置 Gemini CLI client_id,实际 URL: %s", authURL) + } +} + +// --------------------------------------------------------------------------- +// EffectiveOAuthConfig 测试 - 原有测试 +// --------------------------------------------------------------------------- + func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { + // 内置的 Gemini CLI client secret 不嵌入在此仓库中。 + // 测试通过环境变量设置一个假的 secret 来模拟运维配置。 + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + tests := []struct { name string input OAuthConfig @@ -15,7 +445,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr bool }{ { - name: "Google One with built-in client (empty config)", + name: "Google One 使用内置客户端(空配置)", input: OAuthConfig{}, oauthType: "google_one", wantClientID: GeminiCLIOAuthClientID, @@ -23,18 +453,18 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Google One always uses built-in client (even if custom credentials passed)", + name: "Google One 使用自定义客户端(传入自定义凭据时使用自定义)", input: OAuthConfig{ ClientID: "custom-client-id", ClientSecret: "custom-client-secret", }, oauthType: "google_one", wantClientID: "custom-client-id", - wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client + wantScopes: DefaultCodeAssistScopes, wantErr: false, }, { - name: "Google One with built-in client and custom scopes (should filter restricted scopes)", + name: "Google One 内置客户端 + 自定义 scopes(应过滤受限 scopes)", input: OAuthConfig{ Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", }, @@ -44,7 +474,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Google One with built-in client and only restricted scopes (should fallback to default)", + name: "Google One 内置客户端 + 仅受限 scopes(应回退到默认)", input: OAuthConfig{ Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", }, @@ -54,7 +484,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Code Assist with built-in client", + name: "Code Assist 使用内置客户端", input: OAuthConfig{}, oauthType: "code_assist", wantClientID: GeminiCLIOAuthClientID, @@ -84,7 +514,9 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { } func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) { - // Test that Google One with built-in client filters out restricted scopes + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 测试 Google One + 内置客户端过滤受限 scopes cfg, err := EffectiveOAuthConfig(OAuthConfig{ Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile", }, "google_one") @@ -93,21 +525,242 @@ func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) { t.Fatalf("EffectiveOAuthConfig() error = %v", err) } - // Should only contain cloud-platform, userinfo.email, and userinfo.profile - // Should NOT contain generative-language or drive scopes + // 应仅包含 cloud-platform、userinfo.email 和 userinfo.profile + // 不应包含 generative-language 或 drive scopes if strings.Contains(cfg.Scopes, "generative-language") { - t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes) + t.Errorf("使用内置客户端时 Scopes 不应包含 generative-language,实际: %v", cfg.Scopes) } if strings.Contains(cfg.Scopes, "drive") { - t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes) + t.Errorf("使用内置客户端时 Scopes 不应包含 drive,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, "cloud-platform") { - t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes) + t.Errorf("Scopes 应包含 cloud-platform,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, "userinfo.email") { - t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes) + t.Errorf("Scopes 应包含 userinfo.email,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, "userinfo.profile") { - t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes) + t.Errorf("Scopes 应包含 userinfo.profile,实际: %v", cfg.Scopes) + } +} + +// --------------------------------------------------------------------------- +// EffectiveOAuthConfig 测试 - 新增分支覆盖 +// --------------------------------------------------------------------------- + +func TestEffectiveOAuthConfig_OnlyClientID_NoSecret(t *testing.T) { + // 只提供 clientID 不提供 secret 应报错 + _, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "some-client-id", + }, "code_assist") + if err == nil { + t.Error("只提供 ClientID 不提供 ClientSecret 应该报错") + } + if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") { + t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err) + } +} + +func TestEffectiveOAuthConfig_OnlyClientSecret_NoID(t *testing.T) { + // 只提供 secret 不提供 clientID 应报错 + _, err := EffectiveOAuthConfig(OAuthConfig{ + ClientSecret: "some-client-secret", + }, "code_assist") + if err == nil { + t.Error("只提供 ClientSecret 不提供 ClientID 应该报错") + } + if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") { + t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err) + } +} + +func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_BuiltinClient(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // ai_studio 类型,使用内置客户端,scopes 为空 -> 应使用 DefaultCodeAssistScopes(因为内置客户端不能请求 generative-language scope) + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("ai_studio + 内置客户端应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_CustomClient(t *testing.T) { + // ai_studio 类型,使用自定义客户端,scopes 为空 -> 应使用 DefaultAIStudioScopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultAIStudioScopes { + t.Errorf("ai_studio + 自定义客户端应使用 DefaultAIStudioScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_AIStudio_ScopeNormalization(t *testing.T) { + // ai_studio 类型,旧的 generative-language scope 应被归一化为 generative-language.retriever + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/generative-language https://www.googleapis.com/auth/cloud-platform", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if strings.Contains(cfg.Scopes, "auth/generative-language ") || strings.HasSuffix(cfg.Scopes, "auth/generative-language") { + // 确保不包含未归一化的旧 scope(仅 generative-language 而非 generative-language.retriever) + parts := strings.Fields(cfg.Scopes) + for _, p := range parts { + if p == "https://www.googleapis.com/auth/generative-language" { + t.Errorf("ai_studio 应将 generative-language 归一化为 generative-language.retriever,实际 scopes: %q", cfg.Scopes) + } + } + } + if !strings.Contains(cfg.Scopes, "generative-language.retriever") { + t.Errorf("ai_studio 归一化后应包含 generative-language.retriever,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_CommaSeparatedScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 逗号分隔的 scopes 应被归一化为空格分隔 + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/userinfo.email", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + // 应该用空格分隔,而非逗号 + if strings.Contains(cfg.Scopes, ",") { + t.Errorf("逗号分隔的 scopes 应被归一化为空格分隔,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "cloud-platform") { + t.Errorf("归一化后应包含 cloud-platform,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "userinfo.email") { + t.Errorf("归一化后应包含 userinfo.email,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_MixedCommaAndSpaceScopes(t *testing.T) { + // 混合逗号和空格分隔的 scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/cloud-platform, https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + parts := strings.Fields(cfg.Scopes) + if len(parts) != 3 { + t.Errorf("归一化后应有 3 个 scope,实际: %d,scopes: %q", len(parts), cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) { + // 输入中的前后空白应被清理 + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: " custom-id ", + ClientSecret: " custom-secret ", + Scopes: " https://www.googleapis.com/auth/cloud-platform ", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.ClientID != "custom-id" { + t.Errorf("ClientID 应去除前后空白,实际: %q", cfg.ClientID) + } + if cfg.ClientSecret != "custom-secret" { + t.Errorf("ClientSecret 应去除前后空白,实际: %q", cfg.ClientSecret) + } + if cfg.Scopes != "https://www.googleapis.com/auth/cloud-platform" { + t.Errorf("Scopes 应去除前后空白,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "") + + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist") + if err != nil { + t.Fatalf("不设置环境变量时应回退到内置 secret,实际报错: %v", err) + } + if strings.TrimSpace(cfg.ClientSecret) == "" { + t.Error("ClientSecret 不应为空") + } + if cfg.ClientID != GeminiCLIOAuthClientID { + t.Errorf("ClientID 应回退为内置客户端 ID,实际: %q", cfg.ClientID) + } +} + +func TestEffectiveOAuthConfig_AIStudio_BuiltinClient_CustomScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // ai_studio + 内置客户端 + 自定义 scopes -> 应过滤受限 scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + // 内置客户端应过滤 generative-language.retriever + if strings.Contains(cfg.Scopes, "generative-language") { + t.Errorf("ai_studio + 内置客户端应过滤受限 scopes,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "cloud-platform") { + t.Errorf("应保留 cloud-platform scope,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_UnknownOAuthType_DefaultScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 未知的 oauthType 应回退到默认的 code_assist scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "unknown_type") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("未知 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_EmptyOAuthType_DefaultScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 空的 oauthType 应走 default 分支,使用 DefaultCodeAssistScopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("空 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_CustomClient_NoScopeFiltering(t *testing.T) { + // 自定义客户端 + google_one + 包含受限 scopes -> 不应被过滤(因为不是内置客户端) + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", + }, "google_one") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + // 自定义客户端不应过滤任何 scope + if !strings.Contains(cfg.Scopes, "generative-language.retriever") { + t.Errorf("自定义客户端不应过滤 generative-language.retriever,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "drive.readonly") { + t.Errorf("自定义客户端不应过滤 drive.readonly,实际: %q", cfg.Scopes) } } diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 76b7aa91..32e4bc5b 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -18,11 +18,11 @@ package httpclient import ( "fmt" "net/http" - "net/url" "strings" "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) @@ -32,6 +32,7 @@ const ( defaultMaxIdleConns = 100 // 最大空闲连接数 defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时) + validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL ) // Options 定义共享 HTTP 客户端的构建参数 @@ -40,7 +41,6 @@ type Options struct { Timeout time.Duration // 请求总超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间 InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true) - ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding) AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用) @@ -53,6 +53,9 @@ type Options struct { // sharedClients 存储按配置参数缓存的 http.Client 实例 var sharedClients sync.Map +// 允许测试替换校验函数,生产默认指向真实实现。 +var validateResolvedIP = urlvalidator.ValidateResolvedIP + // GetClient 返回共享的 HTTP 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 Transport // 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险 @@ -84,7 +87,7 @@ func buildClient(opts Options) (*http.Client, error) { var rt http.RoundTripper = transport if opts.ValidateResolvedIP && !opts.AllowPrivateHosts { - rt = &validatedTransport{base: transport} + rt = newValidatedTransport(transport) } return &http.Client{ Transport: rt, @@ -116,15 +119,13 @@ func buildTransport(opts Options) (*http.Transport, error) { return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead") } - proxyURL := strings.TrimSpace(opts.ProxyURL) - if proxyURL == "" { - return transport, nil - } - - parsed, err := url.Parse(proxyURL) + _, parsed, err := proxyurl.Parse(opts.ProxyURL) if err != nil { return nil, err } + if parsed == nil { + return transport, nil + } if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { return nil, err @@ -134,12 +135,11 @@ func buildTransport(opts Options) (*http.Transport, error) { } func buildClientKey(opts Options) string { - return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d", + return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%d|%d|%d", strings.TrimSpace(opts.ProxyURL), opts.Timeout.String(), opts.ResponseHeaderTimeout.String(), opts.InsecureSkipVerify, - opts.ProxyStrict, opts.ValidateResolvedIP, opts.AllowPrivateHosts, opts.MaxIdleConns, @@ -149,17 +149,56 @@ func buildClientKey(opts Options) string { } type validatedTransport struct { - base http.RoundTripper + base http.RoundTripper + validatedHosts sync.Map // map[string]time.Time, value 为过期时间 + now func() time.Time +} + +func newValidatedTransport(base http.RoundTripper) *validatedTransport { + return &validatedTransport{ + base: base, + now: time.Now, + } +} + +func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool { + if t == nil { + return false + } + raw, ok := t.validatedHosts.Load(host) + if !ok { + return false + } + expireAt, ok := raw.(time.Time) + if !ok { + t.validatedHosts.Delete(host) + return false + } + if now.Before(expireAt) { + return true + } + t.validatedHosts.Delete(host) + return false } func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) { if req != nil && req.URL != nil { - host := strings.TrimSpace(req.URL.Hostname()) + host := strings.ToLower(strings.TrimSpace(req.URL.Hostname())) if host != "" { - if err := urlvalidator.ValidateResolvedIP(host); err != nil { - return nil, err + now := time.Now() + if t != nil && t.now != nil { + now = t.now() + } + if !t.isValidatedHost(host, now) { + if err := validateResolvedIP(host); err != nil { + return nil, err + } + t.validatedHosts.Store(host, now.Add(validatedHostTTL)) } } } + if t == nil || t.base == nil { + return nil, fmt.Errorf("validated transport base is nil") + } return t.base.RoundTrip(req) } diff --git a/backend/internal/pkg/httpclient/pool_test.go b/backend/internal/pkg/httpclient/pool_test.go new file mode 100644 index 00000000..f945758a --- /dev/null +++ b/backend/internal/pkg/httpclient/pool_test.go @@ -0,0 +1,115 @@ +package httpclient + +import ( + "errors" + "io" + "net/http" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestValidatedTransport_CacheHostValidation(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + var validateCalls int32 + validateResolvedIP = func(host string) error { + atomic.AddInt32(&validateCalls, 1) + require.Equal(t, "api.openai.com", host) + return nil + } + + var baseCalls int32 + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + atomic.AddInt32(&baseCalls, 1) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{}`)), + Header: make(http.Header), + }, nil + }) + + now := time.Unix(1730000000, 0) + transport := newValidatedTransport(base) + transport.now = func() time.Time { return now } + + req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls)) + require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls)) +} + +func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + var validateCalls int32 + validateResolvedIP = func(_ string) error { + atomic.AddInt32(&validateCalls, 1) + return nil + } + + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{}`)), + Header: make(http.Header), + }, nil + }) + + now := time.Unix(1730001000, 0) + transport := newValidatedTransport(base) + transport.now = func() time.Time { return now } + + req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + now = now.Add(validatedHostTTL + time.Second) + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls)) +} + +func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + expectedErr := errors.New("dns rebinding rejected") + validateResolvedIP = func(_ string) error { + return expectedErr + } + + var baseCalls int32 + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + atomic.AddInt32(&baseCalls, 1) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil + }) + + transport := newValidatedTransport(base) + req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.ErrorIs(t, err, expectedErr) + require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls)) +} diff --git a/backend/internal/pkg/httputil/body.go b/backend/internal/pkg/httputil/body.go new file mode 100644 index 00000000..69e99dc5 --- /dev/null +++ b/backend/internal/pkg/httputil/body.go @@ -0,0 +1,37 @@ +package httputil + +import ( + "bytes" + "io" + "net/http" +) + +const ( + requestBodyReadInitCap = 512 + requestBodyReadMaxInitCap = 1 << 20 +) + +// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length. +func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) { + if req == nil || req.Body == nil { + return nil, nil + } + + capHint := requestBodyReadInitCap + if req.ContentLength > 0 { + switch { + case req.ContentLength < int64(requestBodyReadInitCap): + capHint = requestBodyReadInitCap + case req.ContentLength > int64(requestBodyReadMaxInitCap): + capHint = requestBodyReadMaxInitCap + default: + capHint = int(req.ContentLength) + } + } + + buf := bytes.NewBuffer(make([]byte, 0, capHint)) + if _, err := io.Copy(buf, req.Body); err != nil { + return nil, err + } + return buf.Bytes(), nil +} diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go index 97109c0c..f6f77c86 100644 --- a/backend/internal/pkg/ip/ip.go +++ b/backend/internal/pkg/ip/ip.go @@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string { return normalizeIP(c.ClientIP()) } +// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。 +// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。 +// 适用于 ACL / 风控等安全敏感场景。 +func GetTrustedClientIP(c *gin.Context) string { + if c == nil { + return "" + } + return normalizeIP(c.ClientIP()) +} + // normalizeIP 规范化 IP 地址,去除端口号和空格。 func normalizeIP(ip string) string { ip = strings.TrimSpace(ip) @@ -54,29 +64,89 @@ func normalizeIP(ip string) string { return ip } -// isPrivateIP 检查 IP 是否为私有地址。 -func isPrivateIP(ipStr string) bool { - ip := net.ParseIP(ipStr) - if ip == nil { - return false - } +// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析 +var privateNets []*net.IPNet - // 私有 IP 范围 - privateBlocks := []string{ +// CompiledIPRules 表示预编译的 IP 匹配规则。 +// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。 +type CompiledIPRules struct { + CIDRs []*net.IPNet + IPs []net.IP + PatternCount int +} + +func init() { + for _, cidr := range []string{ "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "127.0.0.0/8", "::1/128", "fc00::/7", - } - - for _, block := range privateBlocks { - _, cidr, err := net.ParseCIDR(block) + } { + _, block, err := net.ParseCIDR(cidr) if err != nil { + panic("invalid CIDR: " + cidr) + } + privateNets = append(privateNets, block) + } +} + +// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。 +// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。 +func CompileIPRules(patterns []string) *CompiledIPRules { + compiled := &CompiledIPRules{ + CIDRs: make([]*net.IPNet, 0, len(patterns)), + IPs: make([]net.IP, 0, len(patterns)), + PatternCount: len(patterns), + } + for _, pattern := range patterns { + normalized := strings.TrimSpace(pattern) + if normalized == "" { continue } - if cidr.Contains(ip) { + if strings.Contains(normalized, "/") { + _, cidr, err := net.ParseCIDR(normalized) + if err != nil || cidr == nil { + continue + } + compiled.CIDRs = append(compiled.CIDRs, cidr) + continue + } + parsedIP := net.ParseIP(normalized) + if parsedIP == nil { + continue + } + compiled.IPs = append(compiled.IPs, parsedIP) + } + return compiled +} + +func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool { + if parsedIP == nil || rules == nil { + return false + } + for _, cidr := range rules.CIDRs { + if cidr.Contains(parsedIP) { + return true + } + } + for _, ruleIP := range rules.IPs { + if parsedIP.Equal(ruleIP) { + return true + } + } + return false +} + +// isPrivateIP 检查 IP 是否为私有地址。 +func isPrivateIP(ipStr string) bool { + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + for _, block := range privateNets { + if block.Contains(ip) { return true } } @@ -127,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool { // 2. 如果白名单不为空,IP 必须在白名单中 // 3. 如果白名单为空,允许访问(除非被黑名单拒绝) func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) { + return CheckIPRestrictionWithCompiledRules( + clientIP, + CompileIPRules(whitelist), + CompileIPRules(blacklist), + ) +} + +// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。 +func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) { // 规范化 IP clientIP = normalizeIP(clientIP) if clientIP == "" { return false, "access denied" } + parsedIP := net.ParseIP(clientIP) + if parsedIP == nil { + return false, "access denied" + } // 1. 检查黑名单 - if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) { + if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) { return false, "access denied" } // 2. 检查白名单(如果设置了白名单,IP 必须在其中) - if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) { + if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) { return false, "access denied" } diff --git a/backend/internal/pkg/ip/ip_test.go b/backend/internal/pkg/ip/ip_test.go new file mode 100644 index 00000000..403b2d59 --- /dev/null +++ b/backend/internal/pkg/ip/ip_test.go @@ -0,0 +1,96 @@ +//go:build unit + +package ip + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + name string + ip string + expected bool + }{ + // 私有 IPv4 + {"10.x 私有地址", "10.0.0.1", true}, + {"10.x 私有地址段末", "10.255.255.255", true}, + {"172.16.x 私有地址", "172.16.0.1", true}, + {"172.31.x 私有地址", "172.31.255.255", true}, + {"192.168.x 私有地址", "192.168.1.1", true}, + {"127.0.0.1 本地回环", "127.0.0.1", true}, + {"127.x 回环段", "127.255.255.255", true}, + + // 公网 IPv4 + {"8.8.8.8 公网 DNS", "8.8.8.8", false}, + {"1.1.1.1 公网", "1.1.1.1", false}, + {"172.15.255.255 非私有", "172.15.255.255", false}, + {"172.32.0.0 非私有", "172.32.0.0", false}, + {"11.0.0.1 公网", "11.0.0.1", false}, + + // IPv6 + {"::1 IPv6 回环", "::1", true}, + {"fc00:: IPv6 私有", "fc00::1", true}, + {"fd00:: IPv6 私有", "fd00::1", true}, + {"2001:db8::1 IPv6 公网", "2001:db8::1", false}, + + // 无效输入 + {"空字符串", "", false}, + {"非法字符串", "not-an-ip", false}, + {"不完整 IP", "192.168", false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isPrivateIP(tc.ip) + require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip) + }) + } +} + +func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + require.NoError(t, r.SetTrustedProxies(nil)) + + r.GET("/t", func(c *gin.Context) { + c.String(200, GetTrustedClientIP(c)) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/t", nil) + req.RemoteAddr = "9.9.9.9:12345" + req.Header.Set("X-Forwarded-For", "1.2.3.4") + req.Header.Set("X-Real-IP", "1.2.3.4") + req.Header.Set("CF-Connecting-IP", "1.2.3.4") + r.ServeHTTP(w, req) + + require.Equal(t, 200, w.Code) + require.Equal(t, "9.9.9.9", w.Body.String()) +} + +func TestCheckIPRestrictionWithCompiledRules(t *testing.T) { + whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"}) + blacklist := CompileIPRules([]string{"10.1.1.1"}) + + allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist) + require.True(t, allowed) + require.Equal(t, "", reason) + + allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist) + require.False(t, allowed) + require.Equal(t, "access denied", reason) +} + +func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) { + // 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。 + invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"}) + allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil) + require.False(t, allowed) + require.Equal(t, "access denied", reason) +} diff --git a/backend/internal/pkg/logger/config_adapter.go b/backend/internal/pkg/logger/config_adapter.go new file mode 100644 index 00000000..c34e448b --- /dev/null +++ b/backend/internal/pkg/logger/config_adapter.go @@ -0,0 +1,31 @@ +package logger + +import "github.com/Wei-Shaw/sub2api/internal/config" + +func OptionsFromConfig(cfg config.LogConfig) InitOptions { + return InitOptions{ + Level: cfg.Level, + Format: cfg.Format, + ServiceName: cfg.ServiceName, + Environment: cfg.Environment, + Caller: cfg.Caller, + StacktraceLevel: cfg.StacktraceLevel, + Output: OutputOptions{ + ToStdout: cfg.Output.ToStdout, + ToFile: cfg.Output.ToFile, + FilePath: cfg.Output.FilePath, + }, + Rotation: RotationOptions{ + MaxSizeMB: cfg.Rotation.MaxSizeMB, + MaxBackups: cfg.Rotation.MaxBackups, + MaxAgeDays: cfg.Rotation.MaxAgeDays, + Compress: cfg.Rotation.Compress, + LocalTime: cfg.Rotation.LocalTime, + }, + Sampling: SamplingOptions{ + Enabled: cfg.Sampling.Enabled, + Initial: cfg.Sampling.Initial, + Thereafter: cfg.Sampling.Thereafter, + }, + } +} diff --git a/backend/internal/pkg/logger/logger.go b/backend/internal/pkg/logger/logger.go new file mode 100644 index 00000000..3fca706e --- /dev/null +++ b/backend/internal/pkg/logger/logger.go @@ -0,0 +1,530 @@ +package logger + +import ( + "context" + "fmt" + "io" + "log" + "log/slog" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "gopkg.in/natefinch/lumberjack.v2" +) + +type Level = zapcore.Level + +const ( + LevelDebug = zapcore.DebugLevel + LevelInfo = zapcore.InfoLevel + LevelWarn = zapcore.WarnLevel + LevelError = zapcore.ErrorLevel + LevelFatal = zapcore.FatalLevel +) + +type Sink interface { + WriteLogEvent(event *LogEvent) +} + +type LogEvent struct { + Time time.Time + Level string + Component string + Message string + LoggerName string + Fields map[string]any +} + +var ( + mu sync.RWMutex + global atomic.Pointer[zap.Logger] + sugar atomic.Pointer[zap.SugaredLogger] + atomicLevel zap.AtomicLevel + initOptions InitOptions + currentSink atomic.Value // sinkState + stdLogUndo func() + bootstrapOnce sync.Once +) + +type sinkState struct { + sink Sink +} + +func InitBootstrap() { + bootstrapOnce.Do(func() { + if err := Init(bootstrapOptions()); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "logger bootstrap init failed: %v\n", err) + } + }) +} + +func Init(options InitOptions) error { + mu.Lock() + defer mu.Unlock() + return initLocked(options) +} + +func initLocked(options InitOptions) error { + normalized := options.normalized() + zl, al, err := buildLogger(normalized) + if err != nil { + return err + } + + prev := global.Load() + global.Store(zl) + sugar.Store(zl.Sugar()) + atomicLevel = al + initOptions = normalized + + bridgeSlogLocked() + bridgeStdLogLocked() + + if prev != nil { + _ = prev.Sync() + } + return nil +} + +func Reconfigure(mutator func(*InitOptions) error) error { + mu.Lock() + defer mu.Unlock() + next := initOptions + if mutator != nil { + if err := mutator(&next); err != nil { + return err + } + } + return initLocked(next) +} + +func SetLevel(level string) error { + lv, ok := parseLevel(level) + if !ok { + return fmt.Errorf("invalid log level: %s", level) + } + + mu.Lock() + defer mu.Unlock() + atomicLevel.SetLevel(lv) + initOptions.Level = strings.ToLower(strings.TrimSpace(level)) + return nil +} + +func CurrentLevel() string { + mu.RLock() + defer mu.RUnlock() + if global.Load() == nil { + return "info" + } + return atomicLevel.Level().String() +} + +func SetSink(sink Sink) { + currentSink.Store(sinkState{sink: sink}) +} + +func loadSink() Sink { + v := currentSink.Load() + if v == nil { + return nil + } + state, ok := v.(sinkState) + if !ok { + return nil + } + return state.sink +} + +// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。 +// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。 +func WriteSinkEvent(level, component, message string, fields map[string]any) { + sink := loadSink() + if sink == nil { + return + } + + level = strings.ToLower(strings.TrimSpace(level)) + if level == "" { + level = "info" + } + component = strings.TrimSpace(component) + message = strings.TrimSpace(message) + if message == "" { + return + } + + eventFields := make(map[string]any, len(fields)+1) + for k, v := range fields { + eventFields[k] = v + } + if component != "" { + if _, ok := eventFields["component"]; !ok { + eventFields["component"] = component + } + } + + sink.WriteLogEvent(&LogEvent{ + Time: time.Now(), + Level: level, + Component: component, + Message: message, + LoggerName: component, + Fields: eventFields, + }) +} + +func L() *zap.Logger { + if l := global.Load(); l != nil { + return l + } + return zap.NewNop() +} + +func S() *zap.SugaredLogger { + if s := sugar.Load(); s != nil { + return s + } + return zap.NewNop().Sugar() +} + +func With(fields ...zap.Field) *zap.Logger { + return L().With(fields...) +} + +func Sync() { + l := global.Load() + if l != nil { + _ = l.Sync() + } +} + +func bridgeStdLogLocked() { + if stdLogUndo != nil { + stdLogUndo() + stdLogUndo = nil + } + + prevFlags := log.Flags() + prevPrefix := log.Prefix() + prevWriter := log.Writer() + + log.SetFlags(0) + log.SetPrefix("") + base := global.Load() + if base == nil { + base = zap.NewNop() + } + log.SetOutput(newStdLogBridge(base.Named("stdlog"))) + + stdLogUndo = func() { + log.SetOutput(prevWriter) + log.SetFlags(prevFlags) + log.SetPrefix(prevPrefix) + } +} + +func bridgeSlogLocked() { + base := global.Load() + if base == nil { + base = zap.NewNop() + } + slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog")))) +} + +func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) { + level, _ := parseLevel(options.Level) + atomic := zap.NewAtomicLevelAt(level) + + encoderCfg := zapcore.EncoderConfig{ + TimeKey: "time", + LevelKey: "level", + NameKey: "logger", + CallerKey: "caller", + MessageKey: "msg", + StacktraceKey: "stacktrace", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeDuration: zapcore.MillisDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + + var enc zapcore.Encoder + if options.Format == "console" { + enc = zapcore.NewConsoleEncoder(encoderCfg) + } else { + enc = zapcore.NewJSONEncoder(encoderCfg) + } + + sinkCore := newSinkCore() + cores := make([]zapcore.Core, 0, 3) + + if options.Output.ToStdout { + infoPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= atomic.Level() && lvl < zapcore.WarnLevel + }) + errPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= atomic.Level() && lvl >= zapcore.WarnLevel + }) + cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), infoPriority)) + cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stderr), errPriority)) + } + + if options.Output.ToFile { + fileCore, filePath, fileErr := buildFileCore(enc, atomic, options) + if fileErr != nil { + _, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"日志文件输出初始化失败,降级为仅标准输出\" path=%s err=%v\n", + time.Now().Format(time.RFC3339Nano), + filePath, + fileErr, + ) + } else { + cores = append(cores, fileCore) + } + } + + if len(cores) == 0 { + cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), atomic)) + } + + core := zapcore.NewTee(cores...) + if options.Sampling.Enabled { + core = zapcore.NewSamplerWithOptions(core, samplingTick(), options.Sampling.Initial, options.Sampling.Thereafter) + } + core = sinkCore.Wrap(core) + + stacktraceLevel, _ := parseStacktraceLevel(options.StacktraceLevel) + zapOpts := make([]zap.Option, 0, 5) + if options.Caller { + zapOpts = append(zapOpts, zap.AddCaller()) + } + if stacktraceLevel <= zapcore.FatalLevel { + zapOpts = append(zapOpts, zap.AddStacktrace(stacktraceLevel)) + } + + logger := zap.New(core, zapOpts...).With( + zap.String("service", options.ServiceName), + zap.String("env", options.Environment), + ) + return logger, atomic, nil +} + +func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, string, error) { + filePath := options.Output.FilePath + if strings.TrimSpace(filePath) == "" { + filePath = resolveLogFilePath("") + } + + dir := filepath.Dir(filePath) + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, filePath, err + } + lj := &lumberjack.Logger{ + Filename: filePath, + MaxSize: options.Rotation.MaxSizeMB, + MaxBackups: options.Rotation.MaxBackups, + MaxAge: options.Rotation.MaxAgeDays, + Compress: options.Rotation.Compress, + LocalTime: options.Rotation.LocalTime, + } + return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), filePath, nil +} + +type sinkCore struct { + core zapcore.Core + fields []zapcore.Field +} + +func newSinkCore() *sinkCore { + return &sinkCore{} +} + +func (s *sinkCore) Wrap(core zapcore.Core) zapcore.Core { + cp := *s + cp.core = core + return &cp +} + +func (s *sinkCore) Enabled(level zapcore.Level) bool { + return s.core.Enabled(level) +} + +func (s *sinkCore) With(fields []zapcore.Field) zapcore.Core { + nextFields := append([]zapcore.Field{}, s.fields...) + nextFields = append(nextFields, fields...) + return &sinkCore{ + core: s.core.With(fields), + fields: nextFields, + } +} + +func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry { + // Delegate to inner core (tee) so each sub-core's level enabler is respected. + // Then add ourselves for sink forwarding only. + ce = s.core.Check(entry, ce) + if ce != nil { + ce = ce.AddCore(entry, s) + } + return ce +} + +func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { + // Only handle sink forwarding — the inner cores write via their own + // Write methods (added to CheckedEntry by s.core.Check above). + sink := loadSink() + if sink == nil { + return nil + } + + enc := zapcore.NewMapObjectEncoder() + for _, f := range s.fields { + f.AddTo(enc) + } + for _, f := range fields { + f.AddTo(enc) + } + + event := &LogEvent{ + Time: entry.Time, + Level: strings.ToLower(entry.Level.String()), + Component: entry.LoggerName, + Message: entry.Message, + LoggerName: entry.LoggerName, + Fields: enc.Fields, + } + sink.WriteLogEvent(event) + return nil +} + +func (s *sinkCore) Sync() error { + return s.core.Sync() +} + +type stdLogBridge struct { + logger *zap.Logger +} + +func newStdLogBridge(l *zap.Logger) io.Writer { + if l == nil { + l = zap.NewNop() + } + return &stdLogBridge{logger: l} +} + +func (b *stdLogBridge) Write(p []byte) (int, error) { + msg := normalizeStdLogMessage(string(p)) + if msg == "" { + return len(p), nil + } + + level := inferStdLogLevel(msg) + entry := b.logger.WithOptions(zap.AddCallerSkip(4)) + + switch level { + case LevelDebug: + entry.Debug(msg, zap.Bool("legacy_stdlog", true)) + case LevelWarn: + entry.Warn(msg, zap.Bool("legacy_stdlog", true)) + case LevelError, LevelFatal: + entry.Error(msg, zap.Bool("legacy_stdlog", true)) + default: + entry.Info(msg, zap.Bool("legacy_stdlog", true)) + } + return len(p), nil +} + +func normalizeStdLogMessage(raw string) string { + msg := strings.TrimSpace(strings.ReplaceAll(raw, "\n", " ")) + if msg == "" { + return "" + } + return strings.Join(strings.Fields(msg), " ") +} + +func inferStdLogLevel(msg string) Level { + lower := strings.ToLower(strings.TrimSpace(msg)) + if lower == "" { + return LevelInfo + } + + if strings.HasPrefix(lower, "[debug]") || strings.HasPrefix(lower, "debug:") { + return LevelDebug + } + if strings.HasPrefix(lower, "[warn]") || strings.HasPrefix(lower, "[warning]") || strings.HasPrefix(lower, "warn:") || strings.HasPrefix(lower, "warning:") { + return LevelWarn + } + if strings.HasPrefix(lower, "[error]") || strings.HasPrefix(lower, "error:") || strings.HasPrefix(lower, "fatal:") || strings.HasPrefix(lower, "panic:") { + return LevelError + } + + if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") { + return LevelError + } + if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") { + return LevelWarn + } + return LevelInfo +} + +// LegacyPrintf 用于平滑迁移历史的 printf 风格日志到结构化 logger。 +func LegacyPrintf(component, format string, args ...any) { + msg := normalizeStdLogMessage(fmt.Sprintf(format, args...)) + if msg == "" { + return + } + + initialized := global.Load() != nil + if !initialized { + // 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。 + log.Print(msg) + return + } + + l := L() + if component != "" { + l = l.With(zap.String("component", component)) + } + l = l.WithOptions(zap.AddCallerSkip(1)) + + switch inferStdLogLevel(msg) { + case LevelDebug: + l.Debug(msg, zap.Bool("legacy_printf", true)) + case LevelWarn: + l.Warn(msg, zap.Bool("legacy_printf", true)) + case LevelError, LevelFatal: + l.Error(msg, zap.Bool("legacy_printf", true)) + default: + l.Info(msg, zap.Bool("legacy_printf", true)) + } +} + +type contextKey string + +const loggerContextKey contextKey = "ctx_logger" + +func IntoContext(ctx context.Context, l *zap.Logger) context.Context { + if ctx == nil { + ctx = context.Background() + } + if l == nil { + l = L() + } + return context.WithValue(ctx, loggerContextKey, l) +} + +func FromContext(ctx context.Context) *zap.Logger { + if ctx == nil { + return L() + } + if l, ok := ctx.Value(loggerContextKey).(*zap.Logger); ok && l != nil { + return l + } + return L() +} diff --git a/backend/internal/pkg/logger/logger_test.go b/backend/internal/pkg/logger/logger_test.go new file mode 100644 index 00000000..74aae061 --- /dev/null +++ b/backend/internal/pkg/logger/logger_test.go @@ -0,0 +1,192 @@ +package logger + +import ( + "encoding/json" + "io" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestInit_DualOutput(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "logs", "sub2api.log") + + origStdout := os.Stdout + origStderr := os.Stderr + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stderrR.Close() + _ = stdoutW.Close() + _ = stderrW.Close() + }) + + err = Init(InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: OutputOptions{ + ToStdout: true, + ToFile: true, + FilePath: logPath, + }, + Rotation: RotationOptions{ + MaxSizeMB: 10, + MaxBackups: 2, + MaxAgeDays: 1, + }, + Sampling: SamplingOptions{Enabled: false}, + }) + if err != nil { + t.Fatalf("Init() error: %v", err) + } + + L().Info("dual-output-info") + L().Warn("dual-output-warn") + Sync() + + _ = stdoutW.Close() + _ = stderrW.Close() + stdoutBytes, _ := io.ReadAll(stdoutR) + stderrBytes, _ := io.ReadAll(stderrR) + stdoutText := string(stdoutBytes) + stderrText := string(stderrBytes) + + if !strings.Contains(stdoutText, "dual-output-info") { + t.Fatalf("stdout missing info log: %s", stdoutText) + } + if !strings.Contains(stderrText, "dual-output-warn") { + t.Fatalf("stderr missing warn log: %s", stderrText) + } + + fileBytes, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("read log file: %v", err) + } + fileText := string(fileBytes) + if !strings.Contains(fileText, "dual-output-info") || !strings.Contains(fileText, "dual-output-warn") { + t.Fatalf("file missing logs: %s", fileText) + } +} + +func TestInit_FileOutputFailureDowngrade(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + _, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutW.Close() + _ = stderrR.Close() + _ = stderrW.Close() + }) + + err = Init(InitOptions{ + Level: "info", + Format: "json", + Output: OutputOptions{ + ToStdout: true, + ToFile: true, + FilePath: filepath.Join(os.DevNull, "logs", "sub2api.log"), + }, + Rotation: RotationOptions{ + MaxSizeMB: 10, + MaxBackups: 1, + MaxAgeDays: 1, + }, + }) + if err != nil { + t.Fatalf("Init() should downgrade instead of failing, got: %v", err) + } + + _ = stderrW.Close() + stderrBytes, _ := io.ReadAll(stderrR) + if !strings.Contains(string(stderrBytes), "日志文件输出初始化失败") { + t.Fatalf("stderr should contain fallback warning, got: %s", string(stderrBytes)) + } +} + +func TestInit_CallerShouldPointToCallsite(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + _, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stdoutW.Close() + _ = stderrW.Close() + }) + + if err := Init(InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Caller: true, + Output: OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: SamplingOptions{Enabled: false}, + }); err != nil { + t.Fatalf("Init() error: %v", err) + } + + L().Info("caller-check") + Sync() + _ = stdoutW.Close() + logBytes, _ := io.ReadAll(stdoutR) + + var line string + for _, item := range strings.Split(string(logBytes), "\n") { + if strings.Contains(item, "caller-check") { + line = item + break + } + } + if line == "" { + t.Fatalf("log output missing caller-check: %s", string(logBytes)) + } + + var payload map[string]any + if err := json.Unmarshal([]byte(line), &payload); err != nil { + t.Fatalf("parse log json failed: %v, line=%s", err, line) + } + caller, _ := payload["caller"].(string) + if !strings.Contains(caller, "logger_test.go:") { + t.Fatalf("caller should point to this test file, got: %s", caller) + } +} diff --git a/backend/internal/pkg/logger/options.go b/backend/internal/pkg/logger/options.go new file mode 100644 index 00000000..efcd701c --- /dev/null +++ b/backend/internal/pkg/logger/options.go @@ -0,0 +1,161 @@ +package logger + +import ( + "os" + "path/filepath" + "strings" + "time" +) + +const ( + // DefaultContainerLogPath 为容器内默认日志文件路径。 + DefaultContainerLogPath = "/app/data/logs/sub2api.log" + defaultLogFilename = "sub2api.log" +) + +type InitOptions struct { + Level string + Format string + ServiceName string + Environment string + Caller bool + StacktraceLevel string + Output OutputOptions + Rotation RotationOptions + Sampling SamplingOptions +} + +type OutputOptions struct { + ToStdout bool + ToFile bool + FilePath string +} + +type RotationOptions struct { + MaxSizeMB int + MaxBackups int + MaxAgeDays int + Compress bool + LocalTime bool +} + +type SamplingOptions struct { + Enabled bool + Initial int + Thereafter int +} + +func (o InitOptions) normalized() InitOptions { + out := o + out.Level = strings.ToLower(strings.TrimSpace(out.Level)) + if out.Level == "" { + out.Level = "info" + } + out.Format = strings.ToLower(strings.TrimSpace(out.Format)) + if out.Format == "" { + out.Format = "console" + } + out.ServiceName = strings.TrimSpace(out.ServiceName) + if out.ServiceName == "" { + out.ServiceName = "sub2api" + } + out.Environment = strings.TrimSpace(out.Environment) + if out.Environment == "" { + out.Environment = "production" + } + out.StacktraceLevel = strings.ToLower(strings.TrimSpace(out.StacktraceLevel)) + if out.StacktraceLevel == "" { + out.StacktraceLevel = "error" + } + if !out.Output.ToStdout && !out.Output.ToFile { + out.Output.ToStdout = true + } + out.Output.FilePath = resolveLogFilePath(out.Output.FilePath) + if out.Rotation.MaxSizeMB <= 0 { + out.Rotation.MaxSizeMB = 100 + } + if out.Rotation.MaxBackups < 0 { + out.Rotation.MaxBackups = 10 + } + if out.Rotation.MaxAgeDays < 0 { + out.Rotation.MaxAgeDays = 7 + } + if out.Sampling.Enabled { + if out.Sampling.Initial <= 0 { + out.Sampling.Initial = 100 + } + if out.Sampling.Thereafter <= 0 { + out.Sampling.Thereafter = 100 + } + } + return out +} + +func resolveLogFilePath(explicit string) string { + explicit = strings.TrimSpace(explicit) + if explicit != "" { + return explicit + } + dataDir := strings.TrimSpace(os.Getenv("DATA_DIR")) + if dataDir != "" { + return filepath.Join(dataDir, "logs", defaultLogFilename) + } + return DefaultContainerLogPath +} + +func bootstrapOptions() InitOptions { + return InitOptions{ + Level: "info", + Format: "console", + ServiceName: "sub2api", + Environment: "bootstrap", + Output: OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Rotation: RotationOptions{ + MaxSizeMB: 100, + MaxBackups: 10, + MaxAgeDays: 7, + Compress: true, + LocalTime: true, + }, + Sampling: SamplingOptions{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + } +} + +func parseLevel(level string) (Level, bool) { + switch strings.ToLower(strings.TrimSpace(level)) { + case "debug": + return LevelDebug, true + case "info": + return LevelInfo, true + case "warn": + return LevelWarn, true + case "error": + return LevelError, true + default: + return LevelInfo, false + } +} + +func parseStacktraceLevel(level string) (Level, bool) { + switch strings.ToLower(strings.TrimSpace(level)) { + case "none": + return LevelFatal + 1, true + case "error": + return LevelError, true + case "fatal": + return LevelFatal, true + default: + return LevelError, false + } +} + +func samplingTick() time.Duration { + return time.Second +} diff --git a/backend/internal/pkg/logger/options_test.go b/backend/internal/pkg/logger/options_test.go new file mode 100644 index 00000000..10d50d72 --- /dev/null +++ b/backend/internal/pkg/logger/options_test.go @@ -0,0 +1,102 @@ +package logger + +import ( + "os" + "path/filepath" + "testing" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func TestResolveLogFilePath_Default(t *testing.T) { + t.Setenv("DATA_DIR", "") + got := resolveLogFilePath("") + if got != DefaultContainerLogPath { + t.Fatalf("resolveLogFilePath() = %q, want %q", got, DefaultContainerLogPath) + } +} + +func TestResolveLogFilePath_WithDataDir(t *testing.T) { + t.Setenv("DATA_DIR", "/tmp/sub2api-data") + got := resolveLogFilePath("") + want := filepath.Join("/tmp/sub2api-data", "logs", "sub2api.log") + if got != want { + t.Fatalf("resolveLogFilePath() = %q, want %q", got, want) + } +} + +func TestResolveLogFilePath_ExplicitPath(t *testing.T) { + t.Setenv("DATA_DIR", "/tmp/ignore") + got := resolveLogFilePath("/var/log/custom.log") + if got != "/var/log/custom.log" { + t.Fatalf("resolveLogFilePath() = %q, want explicit path", got) + } +} + +func TestNormalizedOptions_InvalidFallback(t *testing.T) { + t.Setenv("DATA_DIR", "") + opts := InitOptions{ + Level: "TRACE", + Format: "TEXT", + ServiceName: "", + Environment: "", + StacktraceLevel: "panic", + Output: OutputOptions{ + ToStdout: false, + ToFile: false, + }, + Rotation: RotationOptions{ + MaxSizeMB: 0, + MaxBackups: -1, + MaxAgeDays: -1, + }, + Sampling: SamplingOptions{ + Enabled: true, + Initial: 0, + Thereafter: 0, + }, + } + out := opts.normalized() + if out.Level != "trace" { + // normalized 仅做 trim/lower,不做校验;校验在 config 层。 + t.Fatalf("normalized level should preserve value for upstream validation, got %q", out.Level) + } + if !out.Output.ToStdout { + t.Fatalf("normalized output should fallback to stdout") + } + if out.Output.FilePath != DefaultContainerLogPath { + t.Fatalf("normalized file path = %q", out.Output.FilePath) + } + if out.Rotation.MaxSizeMB != 100 { + t.Fatalf("normalized max_size_mb = %d", out.Rotation.MaxSizeMB) + } + if out.Rotation.MaxBackups != 10 { + t.Fatalf("normalized max_backups = %d", out.Rotation.MaxBackups) + } + if out.Rotation.MaxAgeDays != 7 { + t.Fatalf("normalized max_age_days = %d", out.Rotation.MaxAgeDays) + } + if out.Sampling.Initial != 100 || out.Sampling.Thereafter != 100 { + t.Fatalf("normalized sampling defaults invalid: %+v", out.Sampling) + } +} + +func TestBuildFileCore_InvalidPathFallback(t *testing.T) { + t.Setenv("DATA_DIR", "") + opts := bootstrapOptions() + opts.Output.ToFile = true + opts.Output.FilePath = filepath.Join(os.DevNull, "logs", "sub2api.log") + encoderCfg := zapcore.EncoderConfig{ + TimeKey: "time", + LevelKey: "level", + MessageKey: "msg", + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeLevel: zapcore.CapitalLevelEncoder, + } + encoder := zapcore.NewJSONEncoder(encoderCfg) + _, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts) + if err == nil { + t.Fatalf("buildFileCore() expected error for invalid path") + } +} diff --git a/backend/internal/pkg/logger/slog_handler.go b/backend/internal/pkg/logger/slog_handler.go new file mode 100644 index 00000000..602ca1e0 --- /dev/null +++ b/backend/internal/pkg/logger/slog_handler.go @@ -0,0 +1,131 @@ +package logger + +import ( + "context" + "log/slog" + "strings" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type slogZapHandler struct { + logger *zap.Logger + attrs []slog.Attr + groups []string +} + +func newSlogZapHandler(logger *zap.Logger) slog.Handler { + if logger == nil { + logger = zap.NewNop() + } + return &slogZapHandler{ + logger: logger, + attrs: make([]slog.Attr, 0, 8), + groups: make([]string, 0, 4), + } +} + +func (h *slogZapHandler) Enabled(_ context.Context, level slog.Level) bool { + switch { + case level >= slog.LevelError: + return h.logger.Core().Enabled(LevelError) + case level >= slog.LevelWarn: + return h.logger.Core().Enabled(LevelWarn) + case level <= slog.LevelDebug: + return h.logger.Core().Enabled(LevelDebug) + default: + return h.logger.Core().Enabled(LevelInfo) + } +} + +func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error { + fields := make([]zap.Field, 0, len(h.attrs)+record.NumAttrs()+3) + fields = append(fields, slogAttrsToZapFields(h.groups, h.attrs)...) + record.Attrs(func(attr slog.Attr) bool { + fields = append(fields, slogAttrToZapField(h.groups, attr)) + return true + }) + + switch { + case record.Level >= slog.LevelError: + h.logger.Error(record.Message, fields...) + case record.Level >= slog.LevelWarn: + h.logger.Warn(record.Message, fields...) + case record.Level <= slog.LevelDebug: + h.logger.Debug(record.Message, fields...) + default: + h.logger.Info(record.Message, fields...) + } + return nil +} + +func (h *slogZapHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + next := *h + next.attrs = append(append([]slog.Attr{}, h.attrs...), attrs...) + return &next +} + +func (h *slogZapHandler) WithGroup(name string) slog.Handler { + name = strings.TrimSpace(name) + if name == "" { + return h + } + next := *h + next.groups = append(append([]string{}, h.groups...), name) + return &next +} + +func slogAttrsToZapFields(groups []string, attrs []slog.Attr) []zap.Field { + fields := make([]zap.Field, 0, len(attrs)) + for _, attr := range attrs { + fields = append(fields, slogAttrToZapField(groups, attr)) + } + return fields +} + +func slogAttrToZapField(groups []string, attr slog.Attr) zap.Field { + if len(groups) > 0 { + attr.Key = strings.Join(append(append([]string{}, groups...), attr.Key), ".") + } + value := attr.Value.Resolve() + switch value.Kind() { + case slog.KindBool: + return zap.Bool(attr.Key, value.Bool()) + case slog.KindInt64: + return zap.Int64(attr.Key, value.Int64()) + case slog.KindUint64: + return zap.Uint64(attr.Key, value.Uint64()) + case slog.KindFloat64: + return zap.Float64(attr.Key, value.Float64()) + case slog.KindDuration: + return zap.Duration(attr.Key, value.Duration()) + case slog.KindTime: + return zap.Time(attr.Key, value.Time()) + case slog.KindString: + return zap.String(attr.Key, value.String()) + case slog.KindGroup: + groupFields := make([]zap.Field, 0, len(value.Group())) + for _, nested := range value.Group() { + groupFields = append(groupFields, slogAttrToZapField(nil, nested)) + } + return zap.Object(attr.Key, zapObjectFields(groupFields)) + case slog.KindAny: + if t, ok := value.Any().(time.Time); ok { + return zap.Time(attr.Key, t) + } + return zap.Any(attr.Key, value.Any()) + default: + return zap.String(attr.Key, value.String()) + } +} + +type zapObjectFields []zap.Field + +func (z zapObjectFields) MarshalLogObject(enc zapcore.ObjectEncoder) error { + for _, field := range z { + field.AddTo(enc) + } + return nil +} diff --git a/backend/internal/pkg/logger/slog_handler_test.go b/backend/internal/pkg/logger/slog_handler_test.go new file mode 100644 index 00000000..d2b4208d --- /dev/null +++ b/backend/internal/pkg/logger/slog_handler_test.go @@ -0,0 +1,88 @@ +package logger + +import ( + "context" + "log/slog" + "testing" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type captureState struct { + writes []capturedWrite +} + +type capturedWrite struct { + fields []zapcore.Field +} + +type captureCore struct { + state *captureState + withFields []zapcore.Field +} + +func newCaptureCore() *captureCore { + return &captureCore{state: &captureState{}} +} + +func (c *captureCore) Enabled(zapcore.Level) bool { + return true +} + +func (c *captureCore) With(fields []zapcore.Field) zapcore.Core { + nextFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields)) + nextFields = append(nextFields, c.withFields...) + nextFields = append(nextFields, fields...) + return &captureCore{ + state: c.state, + withFields: nextFields, + } +} + +func (c *captureCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry { + return ce.AddCore(entry, c) +} + +func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { + allFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields)) + allFields = append(allFields, c.withFields...) + allFields = append(allFields, fields...) + c.state.writes = append(c.state.writes, capturedWrite{ + fields: allFields, + }) + return nil +} + +func (c *captureCore) Sync() error { + return nil +} + +func TestSlogZapHandler_Handle_DoesNotAppendTimeField(t *testing.T) { + core := newCaptureCore() + handler := newSlogZapHandler(zap.New(core)) + + record := slog.NewRecord(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), slog.LevelInfo, "hello", 0) + record.AddAttrs(slog.String("component", "http.access")) + + if err := handler.Handle(context.Background(), record); err != nil { + t.Fatalf("handle slog record: %v", err) + } + if len(core.state.writes) != 1 { + t.Fatalf("write calls = %d, want 1", len(core.state.writes)) + } + + var hasComponent bool + for _, field := range core.state.writes[0].fields { + if field.Key == "time" { + t.Fatalf("unexpected duplicate time field in slog adapter output") + } + if field.Key == "component" { + hasComponent = true + } + } + if !hasComponent { + t.Fatalf("component field should be preserved") + } +} diff --git a/backend/internal/pkg/logger/stdlog_bridge_test.go b/backend/internal/pkg/logger/stdlog_bridge_test.go new file mode 100644 index 00000000..4482a2ec --- /dev/null +++ b/backend/internal/pkg/logger/stdlog_bridge_test.go @@ -0,0 +1,166 @@ +package logger + +import ( + "io" + "log" + "os" + "strings" + "testing" +) + +func TestInferStdLogLevel(t *testing.T) { + cases := []struct { + msg string + want Level + }{ + {msg: "Warning: queue full", want: LevelWarn}, + {msg: "Forward request failed: timeout", want: LevelError}, + {msg: "[ERROR] upstream unavailable", want: LevelError}, + {msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo}, + {msg: "service started", want: LevelInfo}, + {msg: "debug: cache miss", want: LevelDebug}, + } + + for _, tc := range cases { + got := inferStdLogLevel(tc.msg) + if got != tc.want { + t.Fatalf("inferStdLogLevel(%q)=%v want=%v", tc.msg, got, tc.want) + } + } +} + +func TestNormalizeStdLogMessage(t *testing.T) { + raw := " [TokenRefresh] cycle complete \n total=1 failed=0 \n" + got := normalizeStdLogMessage(raw) + want := "[TokenRefresh] cycle complete total=1 failed=0" + if got != want { + t.Fatalf("normalizeStdLogMessage()=%q want=%q", got, want) + } +} + +func TestStdLogBridgeRoutesLevels(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stdoutW.Close() + _ = stderrR.Close() + _ = stderrW.Close() + }) + + if err := Init(InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: SamplingOptions{Enabled: false}, + }); err != nil { + t.Fatalf("Init() error: %v", err) + } + + log.Printf("service started") + log.Printf("Warning: queue full") + log.Printf("Forward request failed: timeout") + Sync() + + _ = stdoutW.Close() + _ = stderrW.Close() + stdoutBytes, _ := io.ReadAll(stdoutR) + stderrBytes, _ := io.ReadAll(stderrR) + stdoutText := string(stdoutBytes) + stderrText := string(stderrBytes) + + if !strings.Contains(stdoutText, "service started") { + t.Fatalf("stdout missing info log: %s", stdoutText) + } + if !strings.Contains(stderrText, "Warning: queue full") { + t.Fatalf("stderr missing warn log: %s", stderrText) + } + if !strings.Contains(stderrText, "Forward request failed: timeout") { + t.Fatalf("stderr missing error log: %s", stderrText) + } + if !strings.Contains(stderrText, "\"legacy_stdlog\":true") { + t.Fatalf("stderr missing legacy_stdlog marker: %s", stderrText) + } +} + +func TestLegacyPrintfRoutesLevels(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stdoutW.Close() + _ = stderrR.Close() + _ = stderrW.Close() + }) + + if err := Init(InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: SamplingOptions{Enabled: false}, + }); err != nil { + t.Fatalf("Init() error: %v", err) + } + + LegacyPrintf("service.test", "request started") + LegacyPrintf("service.test", "Warning: queue full") + LegacyPrintf("service.test", "forward failed: timeout") + Sync() + + _ = stdoutW.Close() + _ = stderrW.Close() + stdoutBytes, _ := io.ReadAll(stdoutR) + stderrBytes, _ := io.ReadAll(stderrR) + stdoutText := string(stdoutBytes) + stderrText := string(stderrBytes) + + if !strings.Contains(stdoutText, "request started") { + t.Fatalf("stdout missing info log: %s", stdoutText) + } + if !strings.Contains(stderrText, "Warning: queue full") { + t.Fatalf("stderr missing warn log: %s", stderrText) + } + if !strings.Contains(stderrText, "forward failed: timeout") { + t.Fatalf("stderr missing error log: %s", stderrText) + } + if !strings.Contains(stderrText, "\"legacy_printf\":true") { + t.Fatalf("stderr missing legacy_printf marker: %s", stderrText) + } + if !strings.Contains(stderrText, "\"component\":\"service.test\"") { + t.Fatalf("stderr missing component field: %s", stderrText) + } +} diff --git a/backend/internal/pkg/oauth/oauth.go b/backend/internal/pkg/oauth/oauth.go index 33caffd7..cfc91bee 100644 --- a/backend/internal/pkg/oauth/oauth.go +++ b/backend/internal/pkg/oauth/oauth.go @@ -50,6 +50,7 @@ type OAuthSession struct { type SessionStore struct { mu sync.RWMutex sessions map[string]*OAuthSession + stopOnce sync.Once stopCh chan struct{} } @@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore { // Stop stops the cleanup goroutine func (s *SessionStore) Stop() { - close(s.stopCh) + s.stopOnce.Do(func() { + close(s.stopCh) + }) } // Set stores a session diff --git a/backend/internal/pkg/oauth/oauth_test.go b/backend/internal/pkg/oauth/oauth_test.go new file mode 100644 index 00000000..9e59f0f0 --- /dev/null +++ b/backend/internal/pkg/oauth/oauth_test.go @@ -0,0 +1,43 @@ +package oauth + +import ( + "sync" + "testing" + "time" +) + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + store.Stop() + store.Stop() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} + +func TestSessionStore_Stop_Concurrent(t *testing.T) { + store := NewSessionStore() + + var wg sync.WaitGroup + for range 50 { + wg.Add(1) + go func() { + defer wg.Done() + store.Stop() + }() + } + + wg.Wait() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index fd24b11d..4bbc68e7 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -15,8 +15,8 @@ type Model struct { // DefaultModels OpenAI models list var DefaultModels = []Model{ - {ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, + {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"}, {ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"}, diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index df972a13..8bdcbe16 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -17,6 +17,8 @@ import ( const ( // OAuth Client ID for OpenAI (Codex CLI official) ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + // OAuth Client ID for Sora mobile flow (aligned with sora2api) + SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK" // OAuth endpoints AuthorizeURL = "https://auth.openai.com/oauth/authorize" @@ -34,10 +36,18 @@ const ( SessionTTL = 30 * time.Minute ) +const ( + // OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client. + OAuthPlatformOpenAI = "openai" + // OAuthPlatformSora uses Sora OAuth client. + OAuthPlatformSora = "sora" +) + // OAuthSession stores OAuth flow state for OpenAI type OAuthSession struct { State string `json:"state"` CodeVerifier string `json:"code_verifier"` + ClientID string `json:"client_id,omitempty"` ProxyURL string `json:"proxy_url,omitempty"` RedirectURI string `json:"redirect_uri"` CreatedAt time.Time `json:"created_at"` @@ -47,6 +57,7 @@ type OAuthSession struct { type SessionStore struct { mu sync.RWMutex sessions map[string]*OAuthSession + stopOnce sync.Once stopCh chan struct{} } @@ -92,7 +103,9 @@ func (s *SessionStore) Delete(sessionID string) { // Stop stops the cleanup goroutine func (s *SessionStore) Stop() { - close(s.stopCh) + s.stopOnce.Do(func() { + close(s.stopCh) + }) } // cleanup removes expired sessions periodically @@ -169,13 +182,20 @@ func base64URLEncode(data []byte) string { // BuildAuthorizationURL builds the OpenAI OAuth authorization URL func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string { + return BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, OAuthPlatformOpenAI) +} + +// BuildAuthorizationURLForPlatform builds authorization URL by platform. +func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string { if redirectURI == "" { redirectURI = DefaultRedirectURI } + clientID, codexFlow := OAuthClientConfigByPlatform(platform) + params := url.Values{} params.Set("response_type", "code") - params.Set("client_id", ClientID) + params.Set("client_id", clientID) params.Set("redirect_uri", redirectURI) params.Set("scope", DefaultScopes) params.Set("state", state) @@ -183,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string { params.Set("code_challenge_method", "S256") // OpenAI specific parameters params.Set("id_token_add_organizations", "true") - params.Set("codex_cli_simplified_flow", "true") + if codexFlow { + params.Set("codex_cli_simplified_flow", "true") + } return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()) } +// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled. +// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri), +// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。 +func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) { + switch strings.ToLower(strings.TrimSpace(platform)) { + case OAuthPlatformSora: + return ClientID, false + default: + return ClientID, true + } +} + // TokenRequest represents the token exchange request body type TokenRequest struct { GrantType string `json:"grant_type"` @@ -291,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string { return params.Encode() } -// ParseIDToken parses the ID Token JWT and extracts claims -// Note: This does NOT verify the signature - it only decodes the payload -// For production, you should verify the token signature using OpenAI's public keys +// ParseIDToken parses the ID Token JWT and extracts claims. +// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。 +// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名: +// +// https://auth.openai.com/.well-known/jwks.json func ParseIDToken(idToken string) (*IDTokenClaims, error) { parts := strings.Split(idToken, ".") if len(parts) != 3 { @@ -324,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) { return nil, fmt.Errorf("failed to parse JWT claims: %w", err) } + // 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌) + const clockSkewTolerance = 120 // 秒 + now := time.Now().Unix() + if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance { + return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance) + } + return &claims, nil } diff --git a/backend/internal/pkg/openai/oauth_test.go b/backend/internal/pkg/openai/oauth_test.go new file mode 100644 index 00000000..2970addf --- /dev/null +++ b/backend/internal/pkg/openai/oauth_test.go @@ -0,0 +1,82 @@ +package openai + +import ( + "net/url" + "sync" + "testing" + "time" +) + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + store.Stop() + store.Stop() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} + +func TestSessionStore_Stop_Concurrent(t *testing.T) { + store := NewSessionStore() + + var wg sync.WaitGroup + for range 50 { + wg.Add(1) + go func() { + defer wg.Done() + store.Stop() + }() + } + + wg.Wait() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} + +func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) { + authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI) + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("Parse URL failed: %v", err) + } + q := parsed.Query() + if got := q.Get("client_id"); got != ClientID { + t.Fatalf("client_id mismatch: got=%q want=%q", got, ClientID) + } + if got := q.Get("codex_cli_simplified_flow"); got != "true" { + t.Fatalf("codex flow mismatch: got=%q want=true", got) + } + if got := q.Get("id_token_add_organizations"); got != "true" { + t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) + } +} + +// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id, +// 但不启用 codex_cli_simplified_flow。 +func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) { + authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora) + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("Parse URL failed: %v", err) + } + q := parsed.Query() + if got := q.Get("client_id"); got != ClientID { + t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID) + } + if got := q.Get("codex_cli_simplified_flow"); got != "" { + t.Fatalf("codex flow should be empty for sora, got=%q", got) + } + if got := q.Get("id_token_add_organizations"); got != "true" { + t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) + } +} diff --git a/backend/internal/pkg/openai/request.go b/backend/internal/pkg/openai/request.go index 5b049ddc..c24d1273 100644 --- a/backend/internal/pkg/openai/request.go +++ b/backend/internal/pkg/openai/request.go @@ -1,5 +1,7 @@ package openai +import "strings" + // CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns // Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2" var CodexCLIUserAgentPrefixes = []string{ @@ -7,10 +9,67 @@ var CodexCLIUserAgentPrefixes = []string{ "codex_cli_rs/", } +// CodexOfficialClientUserAgentPrefixes matches Codex 官方客户端家族 User-Agent 前缀。 +// 该列表仅用于 OpenAI OAuth `codex_cli_only` 访问限制判定。 +var CodexOfficialClientUserAgentPrefixes = []string{ + "codex_cli_rs/", + "codex_vscode/", + "codex_app/", + "codex_chatgpt_desktop/", + "codex_atlas/", + "codex_exec/", + "codex_sdk_ts/", + "codex ", +} + +// CodexOfficialClientOriginatorPrefixes matches Codex 官方客户端家族 originator 前缀。 +// 说明:OpenAI 官方 Codex 客户端并不只使用固定的 codex_app 标识。 +// 例如 codex_cli_rs、codex_vscode、codex_chatgpt_desktop、codex_atlas、codex_exec、codex_sdk_ts 等。 +var CodexOfficialClientOriginatorPrefixes = []string{ + "codex_", + "codex ", +} + // IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request func IsCodexCLIRequest(userAgent string) bool { - for _, prefix := range CodexCLIUserAgentPrefixes { - if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix { + ua := normalizeCodexClientHeader(userAgent) + if ua == "" { + return false + } + return matchCodexClientHeaderPrefixes(ua, CodexCLIUserAgentPrefixes) +} + +// IsCodexOfficialClientRequest checks if the User-Agent indicates a Codex 官方客户端请求。 +// 与 IsCodexCLIRequest 解耦,避免影响历史兼容逻辑。 +func IsCodexOfficialClientRequest(userAgent string) bool { + ua := normalizeCodexClientHeader(userAgent) + if ua == "" { + return false + } + return matchCodexClientHeaderPrefixes(ua, CodexOfficialClientUserAgentPrefixes) +} + +// IsCodexOfficialClientOriginator checks if originator indicates a Codex 官方客户端请求。 +func IsCodexOfficialClientOriginator(originator string) bool { + v := normalizeCodexClientHeader(originator) + if v == "" { + return false + } + return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes) +} + +func normalizeCodexClientHeader(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} + +func matchCodexClientHeaderPrefixes(value string, prefixes []string) bool { + for _, prefix := range prefixes { + normalizedPrefix := normalizeCodexClientHeader(prefix) + if normalizedPrefix == "" { + continue + } + // 优先前缀匹配;若 UA/Originator 被网关拼接为复合字符串时,退化为包含匹配。 + if strings.HasPrefix(value, normalizedPrefix) || strings.Contains(value, normalizedPrefix) { return true } } diff --git a/backend/internal/pkg/openai/request_test.go b/backend/internal/pkg/openai/request_test.go new file mode 100644 index 00000000..508bf561 --- /dev/null +++ b/backend/internal/pkg/openai/request_test.go @@ -0,0 +1,87 @@ +package openai + +import "testing" + +func TestIsCodexCLIRequest(t *testing.T) { + tests := []struct { + name string + ua string + want bool + }{ + {name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.1.0", want: true}, + {name: "codex_vscode 前缀", ua: "codex_vscode/1.2.3", want: true}, + {name: "大小写混合", ua: "Codex_CLI_Rs/0.1.0", want: true}, + {name: "复合 UA 包含 codex", ua: "Mozilla/5.0 codex_cli_rs/0.1.0", want: true}, + {name: "空白包裹", ua: " codex_vscode/1.2.3 ", want: true}, + {name: "非 codex", ua: "curl/8.0.1", want: false}, + {name: "空字符串", ua: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexCLIRequest(tt.ua) + if got != tt.want { + t.Fatalf("IsCodexCLIRequest(%q) = %v, want %v", tt.ua, got, tt.want) + } + }) + } +} + +func TestIsCodexOfficialClientRequest(t *testing.T) { + tests := []struct { + name string + ua string + want bool + }{ + {name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.98.0", want: true}, + {name: "codex_vscode 前缀", ua: "codex_vscode/1.0.0", want: true}, + {name: "codex_app 前缀", ua: "codex_app/0.1.0", want: true}, + {name: "codex_chatgpt_desktop 前缀", ua: "codex_chatgpt_desktop/1.0.0", want: true}, + {name: "codex_atlas 前缀", ua: "codex_atlas/1.0.0", want: true}, + {name: "codex_exec 前缀", ua: "codex_exec/0.1.0", want: true}, + {name: "codex_sdk_ts 前缀", ua: "codex_sdk_ts/0.1.0", want: true}, + {name: "Codex 桌面 UA", ua: "Codex Desktop/1.2.3", want: true}, + {name: "复合 UA 包含 codex_app", ua: "Mozilla/5.0 codex_app/0.1.0", want: true}, + {name: "大小写混合", ua: "Codex_VSCode/1.2.3", want: true}, + {name: "非 codex", ua: "curl/8.0.1", want: false}, + {name: "空字符串", ua: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexOfficialClientRequest(tt.ua) + if got != tt.want { + t.Fatalf("IsCodexOfficialClientRequest(%q) = %v, want %v", tt.ua, got, tt.want) + } + }) + } +} + +func TestIsCodexOfficialClientOriginator(t *testing.T) { + tests := []struct { + name string + originator string + want bool + }{ + {name: "codex_cli_rs", originator: "codex_cli_rs", want: true}, + {name: "codex_vscode", originator: "codex_vscode", want: true}, + {name: "codex_app", originator: "codex_app", want: true}, + {name: "codex_chatgpt_desktop", originator: "codex_chatgpt_desktop", want: true}, + {name: "codex_atlas", originator: "codex_atlas", want: true}, + {name: "codex_exec", originator: "codex_exec", want: true}, + {name: "codex_sdk_ts", originator: "codex_sdk_ts", want: true}, + {name: "Codex 前缀", originator: "Codex Desktop", want: true}, + {name: "空白包裹", originator: " codex_vscode ", want: true}, + {name: "非 codex", originator: "my_client", want: false}, + {name: "空字符串", originator: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexOfficialClientOriginator(tt.originator) + if got != tt.want { + t.Fatalf("IsCodexOfficialClientOriginator(%q) = %v, want %v", tt.originator, got, tt.want) + } + }) + } +} diff --git a/backend/internal/pkg/proxyurl/parse.go b/backend/internal/pkg/proxyurl/parse.go new file mode 100644 index 00000000..217556f2 --- /dev/null +++ b/backend/internal/pkg/proxyurl/parse.go @@ -0,0 +1,66 @@ +// Package proxyurl 提供代理 URL 的统一验证(fail-fast,无效代理不回退直连) +// +// 所有需要解析代理 URL 的地方必须通过此包的 Parse 函数。 +// 直接使用 url.Parse 处理代理 URL 是被禁止的。 +// 这确保了 fail-fast 行为:无效代理配置在创建时立即失败, +// 而不是在运行时静默回退到直连(产生 IP 关联风险)。 +package proxyurl + +import ( + "fmt" + "net/url" + "strings" +) + +// allowedSchemes 代理协议白名单 +var allowedSchemes = map[string]bool{ + "http": true, + "https": true, + "socks5": true, + "socks5h": true, +} + +// Parse 解析并验证代理 URL。 +// +// 语义: +// - 空字符串 → ("", nil, nil),表示直连 +// - 非空且有效 → (trimmed, *url.URL, nil) +// - 非空但无效 → ("", nil, error),fail-fast 不回退 +// +// 验证规则: +// - TrimSpace 后为空视为直连 +// - url.Parse 失败返回 error(不含原始 URL,防凭据泄露) +// - Host 为空返回 error(用 Redacted() 脱敏) +// - Scheme 必须为 http/https/socks5/socks5h +// - socks5:// 自动升级为 socks5h://(确保 DNS 由代理端解析,防止 DNS 泄漏) +func Parse(raw string) (trimmed string, parsed *url.URL, err error) { + trimmed = strings.TrimSpace(raw) + if trimmed == "" { + return "", nil, nil + } + + parsed, err = url.Parse(trimmed) + if err != nil { + // 不使用 %w 包装,避免 url.Parse 的底层错误消息泄漏原始 URL(可能含凭据) + return "", nil, fmt.Errorf("invalid proxy URL: %v", err) + } + + if parsed.Host == "" || parsed.Hostname() == "" { + return "", nil, fmt.Errorf("proxy URL missing host: %s", parsed.Redacted()) + } + + scheme := strings.ToLower(parsed.Scheme) + if !allowedSchemes[scheme] { + return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, https, socks5, socks5h)", scheme) + } + + // 自动升级 socks5 → socks5h,确保 DNS 由代理端解析,防止 DNS 泄漏。 + // Go 的 golang.org/x/net/proxy 对 socks5:// 默认在客户端本地解析 DNS, + // 仅 socks5h:// 才将域名发送给代理端做远程 DNS 解析。 + if scheme == "socks5" { + parsed.Scheme = "socks5h" + trimmed = parsed.String() + } + + return trimmed, parsed, nil +} diff --git a/backend/internal/pkg/proxyurl/parse_test.go b/backend/internal/pkg/proxyurl/parse_test.go new file mode 100644 index 00000000..5fb57c16 --- /dev/null +++ b/backend/internal/pkg/proxyurl/parse_test.go @@ -0,0 +1,215 @@ +package proxyurl + +import ( + "strings" + "testing" +) + +func TestParse_空字符串直连(t *testing.T) { + trimmed, parsed, err := Parse("") + if err != nil { + t.Fatalf("空字符串应直连: %v", err) + } + if trimmed != "" { + t.Errorf("trimmed 应为空: got %q", trimmed) + } + if parsed != nil { + t.Errorf("parsed 应为 nil: got %v", parsed) + } +} + +func TestParse_空白字符串直连(t *testing.T) { + trimmed, parsed, err := Parse(" ") + if err != nil { + t.Fatalf("空白字符串应直连: %v", err) + } + if trimmed != "" { + t.Errorf("trimmed 应为空: got %q", trimmed) + } + if parsed != nil { + t.Errorf("parsed 应为 nil: got %v", parsed) + } +} + +func TestParse_有效HTTP代理(t *testing.T) { + trimmed, parsed, err := Parse("http://proxy.example.com:8080") + if err != nil { + t.Fatalf("有效 HTTP 代理应成功: %v", err) + } + if trimmed != "http://proxy.example.com:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } + if parsed == nil { + t.Fatal("parsed 不应为 nil") + } + if parsed.Host != "proxy.example.com:8080" { + t.Errorf("Host 不匹配: got %q", parsed.Host) + } +} + +func TestParse_有效HTTPS代理(t *testing.T) { + _, parsed, err := Parse("https://proxy.example.com:443") + if err != nil { + t.Fatalf("有效 HTTPS 代理应成功: %v", err) + } + if parsed.Scheme != "https" { + t.Errorf("Scheme 不匹配: got %q", parsed.Scheme) + } +} + +func TestParse_有效SOCKS5代理_自动升级为SOCKS5H(t *testing.T) { + trimmed, parsed, err := Parse("socks5://127.0.0.1:1080") + if err != nil { + t.Fatalf("有效 SOCKS5 代理应成功: %v", err) + } + // socks5 自动升级为 socks5h,确保 DNS 由代理端解析 + if trimmed != "socks5h://127.0.0.1:1080" { + t.Errorf("trimmed 应升级为 socks5h: got %q", trimmed) + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } +} + +func TestParse_无效URL(t *testing.T) { + _, _, err := Parse("://invalid") + if err == nil { + t.Fatal("无效 URL 应返回错误") + } + if !strings.Contains(err.Error(), "invalid proxy URL") { + t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error()) + } +} + +func TestParse_缺少Host(t *testing.T) { + _, _, err := Parse("http://") + if err == nil { + t.Fatal("缺少 host 应返回错误") + } + if !strings.Contains(err.Error(), "missing host") { + t.Errorf("错误信息应包含 'missing host': got %s", err.Error()) + } +} + +func TestParse_不支持的Scheme(t *testing.T) { + _, _, err := Parse("ftp://proxy.example.com:21") + if err == nil { + t.Fatal("不支持的 scheme 应返回错误") + } + if !strings.Contains(err.Error(), "unsupported proxy scheme") { + t.Errorf("错误信息应包含 'unsupported proxy scheme': got %s", err.Error()) + } +} + +func TestParse_含密码URL脱敏(t *testing.T) { + // 场景 1: 带密码的 socks5 URL 应成功解析并升级为 socks5h + trimmed, parsed, err := Parse("socks5://user:secret_password@proxy.local:1080") + if err != nil { + t.Fatalf("含密码的有效 URL 应成功: %v", err) + } + if trimmed == "" || parsed == nil { + t.Fatal("应返回非空结果") + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } + if !strings.HasPrefix(trimmed, "socks5h://") { + t.Errorf("trimmed 应以 socks5h:// 开头: got %q", trimmed) + } + if parsed.User == nil { + t.Error("升级后应保留 UserInfo") + } + + // 场景 2: 带密码但缺少 host(触发 Redacted 脱敏路径) + _, _, err = Parse("http://user:secret_password@:0/") + if err == nil { + t.Fatal("缺少 host 应返回错误") + } + if strings.Contains(err.Error(), "secret_password") { + t.Error("错误信息不应包含明文密码") + } + if !strings.Contains(err.Error(), "missing host") { + t.Errorf("错误信息应包含 'missing host': got %s", err.Error()) + } +} + +func TestParse_带空白的有效URL(t *testing.T) { + trimmed, parsed, err := Parse(" http://proxy.example.com:8080 ") + if err != nil { + t.Fatalf("带空白的有效 URL 应成功: %v", err) + } + if trimmed != "http://proxy.example.com:8080" { + t.Errorf("trimmed 应去除空白: got %q", trimmed) + } + if parsed == nil { + t.Fatal("parsed 不应为 nil") + } +} + +func TestParse_Scheme大小写不敏感(t *testing.T) { + // 大写 SOCKS5 应被接受并升级为 socks5h + trimmed, parsed, err := Parse("SOCKS5://proxy.example.com:1080") + if err != nil { + t.Fatalf("大写 SOCKS5 应被接受: %v", err) + } + if parsed.Scheme != "socks5h" { + t.Errorf("大写 SOCKS5 Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } + if !strings.HasPrefix(trimmed, "socks5h://") { + t.Errorf("大写 SOCKS5 trimmed 应升级为 socks5h://: got %q", trimmed) + } + + // 大写 HTTP 应被接受(不变) + _, _, err = Parse("HTTP://proxy.example.com:8080") + if err != nil { + t.Fatalf("大写 HTTP 应被接受: %v", err) + } +} + +func TestParse_带认证的有效代理(t *testing.T) { + trimmed, parsed, err := Parse("http://user:pass@proxy.example.com:8080") + if err != nil { + t.Fatalf("带认证的代理 URL 应成功: %v", err) + } + if parsed.User == nil { + t.Error("应保留 UserInfo") + } + if trimmed != "http://user:pass@proxy.example.com:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } +} + +func TestParse_IPv6地址(t *testing.T) { + trimmed, parsed, err := Parse("http://[::1]:8080") + if err != nil { + t.Fatalf("IPv6 代理 URL 应成功: %v", err) + } + if parsed.Hostname() != "::1" { + t.Errorf("Hostname 不匹配: got %q", parsed.Hostname()) + } + if trimmed != "http://[::1]:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } +} + +func TestParse_SOCKS5H保持不变(t *testing.T) { + trimmed, parsed, err := Parse("socks5h://proxy.local:1080") + if err != nil { + t.Fatalf("有效 SOCKS5H 代理应成功: %v", err) + } + // socks5h 不需要升级,应保持原样 + if trimmed != "socks5h://proxy.local:1080" { + t.Errorf("trimmed 不应变化: got %q", trimmed) + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应保持 socks5h: got %q", parsed.Scheme) + } +} + +func TestParse_无Scheme裸地址(t *testing.T) { + // 无 scheme 的裸地址,Go url.Parse 将其视为 path,Host 为空 + _, _, err := Parse("proxy.example.com:8080") + if err == nil { + t.Fatal("无 scheme 的裸地址应返回错误") + } +} diff --git a/backend/internal/pkg/proxyutil/dialer.go b/backend/internal/pkg/proxyutil/dialer.go index 91b224a2..e437cae3 100644 --- a/backend/internal/pkg/proxyutil/dialer.go +++ b/backend/internal/pkg/proxyutil/dialer.go @@ -2,7 +2,11 @@ // // 支持的代理协议: // - HTTP/HTTPS: 通过 Transport.Proxy 设置 -// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS) +// - SOCKS5: 通过 Transport.DialContext 设置(客户端本地解析 DNS) +// - SOCKS5H: 通过 Transport.DialContext 设置(代理端远程解析 DNS,推荐) +// +// 注意:proxyurl.Parse() 会自动将 socks5:// 升级为 socks5h://, +// 确保 DNS 也由代理端解析,防止 DNS 泄漏。 package proxyutil import ( @@ -20,7 +24,8 @@ import ( // // 支持的协议: // - http/https: 设置 transport.Proxy -// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS) +// - socks5: 设置 transport.DialContext(客户端本地解析 DNS) +// - socks5h: 设置 transport.DialContext(代理端远程解析 DNS,推荐) // // 参数: // - transport: 需要配置的 http.Transport diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index c5b41d6e..0519c2cc 100644 --- a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -7,6 +7,7 @@ import ( "net/http" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" "github.com/gin-gonic/gin" ) @@ -78,7 +79,7 @@ func ErrorFrom(c *gin.Context, err error) bool { // Log internal errors with full details for debugging if statusCode >= 500 && c.Request != nil { - log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error()) + log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error())) } ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata) diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go index ef31ca3c..0debce5f 100644 --- a/backend/internal/pkg/response/response_test.go +++ b/backend/internal/pkg/response/response_test.go @@ -14,6 +14,44 @@ import ( "github.com/stretchr/testify/require" ) +// ---------- 辅助函数 ---------- + +// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体 +func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response { + t.Helper() + var got Response + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) + return got +} + +// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData) +func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) { + t.Helper() + // 先用 raw json 解析,因为 Data 是 any 类型 + var raw struct { + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason,omitempty"` + Data json.RawMessage `json:"data,omitempty"` + } + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw)) + + var pd PaginatedData + require.NoError(t, json.Unmarshal(raw.Data, &pd)) + + return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd +} + +// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination +func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil) + return w, c +} + +// ---------- 现有测试 ---------- + func TestErrorWithDetails(t *testing.T) { gin.SetMode(gin.TestMode) @@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) { }) } } + +// ---------- 新增测试 ---------- + +func TestSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + data any + wantCode int + wantBody Response + }{ + { + name: "返回字符串数据", + data: "hello", + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success", Data: "hello"}, + }, + { + name: "返回nil数据", + data: nil, + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success"}, + }, + { + name: "返回map数据", + data: map[string]string{"key": "value"}, + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Success(c, tt.data) + + require.Equal(t, tt.wantCode, w.Code) + + // 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice + got := parseResponseBody(t, w) + require.Equal(t, 0, got.Code) + require.Equal(t, "success", got.Message) + + if tt.data == nil { + require.Nil(t, got.Data) + } else { + require.NotNil(t, got.Data) + } + }) + } +} + +func TestCreated(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + data any + wantCode int + }{ + { + name: "创建成功_返回数据", + data: map[string]int{"id": 42}, + wantCode: http.StatusCreated, + }, + { + name: "创建成功_nil数据", + data: nil, + wantCode: http.StatusCreated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Created(c, tt.data) + + require.Equal(t, tt.wantCode, w.Code) + + got := parseResponseBody(t, w) + require.Equal(t, 0, got.Code) + require.Equal(t, "success", got.Message) + }) + } +} + +func TestError(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + statusCode int + message string + }{ + { + name: "400错误", + statusCode: http.StatusBadRequest, + message: "bad request", + }, + { + name: "500错误", + statusCode: http.StatusInternalServerError, + message: "internal error", + }, + { + name: "自定义状态码", + statusCode: 418, + message: "I'm a teapot", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Error(c, tt.statusCode, tt.message) + + require.Equal(t, tt.statusCode, w.Code) + + got := parseResponseBody(t, w) + require.Equal(t, tt.statusCode, got.Code) + require.Equal(t, tt.message, got.Message) + require.Empty(t, got.Reason) + require.Nil(t, got.Metadata) + require.Nil(t, got.Data) + }) + } +} + +func TestBadRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + BadRequest(c, "参数无效") + + require.Equal(t, http.StatusBadRequest, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusBadRequest, got.Code) + require.Equal(t, "参数无效", got.Message) +} + +func TestUnauthorized(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Unauthorized(c, "未登录") + + require.Equal(t, http.StatusUnauthorized, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusUnauthorized, got.Code) + require.Equal(t, "未登录", got.Message) +} + +func TestForbidden(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Forbidden(c, "无权限") + + require.Equal(t, http.StatusForbidden, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusForbidden, got.Code) + require.Equal(t, "无权限", got.Message) +} + +func TestNotFound(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + NotFound(c, "资源不存在") + + require.Equal(t, http.StatusNotFound, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusNotFound, got.Code) + require.Equal(t, "资源不存在", got.Message) +} + +func TestInternalError(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + InternalError(c, "服务器内部错误") + + require.Equal(t, http.StatusInternalServerError, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusInternalServerError, got.Code) + require.Equal(t, "服务器内部错误", got.Message) +} + +func TestPaginated(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + items any + total int64 + page int + pageSize int + wantPages int + wantTotal int64 + wantPage int + wantPageSize int + }{ + { + name: "标准分页_多页", + items: []string{"a", "b"}, + total: 25, + page: 1, + pageSize: 10, + wantPages: 3, + wantTotal: 25, + wantPage: 1, + wantPageSize: 10, + }, + { + name: "总数刚好整除", + items: []string{"a"}, + total: 20, + page: 2, + pageSize: 10, + wantPages: 2, + wantTotal: 20, + wantPage: 2, + wantPageSize: 10, + }, + { + name: "总数为0_pages至少为1", + items: []string{}, + total: 0, + page: 1, + pageSize: 10, + wantPages: 1, + wantTotal: 0, + wantPage: 1, + wantPageSize: 10, + }, + { + name: "单页数据", + items: []int{1, 2, 3}, + total: 3, + page: 1, + pageSize: 20, + wantPages: 1, + wantTotal: 3, + wantPage: 1, + wantPageSize: 20, + }, + { + name: "总数为1", + items: []string{"only"}, + total: 1, + page: 1, + pageSize: 10, + wantPages: 1, + wantTotal: 1, + wantPage: 1, + wantPageSize: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Paginated(c, tt.items, tt.total, tt.page, tt.pageSize) + + require.Equal(t, http.StatusOK, w.Code) + + resp, pd := parsePaginatedBody(t, w) + require.Equal(t, 0, resp.Code) + require.Equal(t, "success", resp.Message) + require.Equal(t, tt.wantTotal, pd.Total) + require.Equal(t, tt.wantPage, pd.Page) + require.Equal(t, tt.wantPageSize, pd.PageSize) + require.Equal(t, tt.wantPages, pd.Pages) + }) + } +} + +func TestPaginatedWithResult(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + items any + pagination *PaginationResult + wantTotal int64 + wantPage int + wantPageSize int + wantPages int + }{ + { + name: "正常分页结果", + items: []string{"a", "b"}, + pagination: &PaginationResult{ + Total: 50, + Page: 3, + PageSize: 10, + Pages: 5, + }, + wantTotal: 50, + wantPage: 3, + wantPageSize: 10, + wantPages: 5, + }, + { + name: "pagination为nil_使用默认值", + items: []string{}, + pagination: nil, + wantTotal: 0, + wantPage: 1, + wantPageSize: 20, + wantPages: 1, + }, + { + name: "单页结果", + items: []int{1}, + pagination: &PaginationResult{ + Total: 1, + Page: 1, + PageSize: 20, + Pages: 1, + }, + wantTotal: 1, + wantPage: 1, + wantPageSize: 20, + wantPages: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + PaginatedWithResult(c, tt.items, tt.pagination) + + require.Equal(t, http.StatusOK, w.Code) + + resp, pd := parsePaginatedBody(t, w) + require.Equal(t, 0, resp.Code) + require.Equal(t, "success", resp.Message) + require.Equal(t, tt.wantTotal, pd.Total) + require.Equal(t, tt.wantPage, pd.Page) + require.Equal(t, tt.wantPageSize, pd.PageSize) + require.Equal(t, tt.wantPages, pd.Pages) + }) + } +} + +func TestParsePagination(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + query string + wantPage int + wantPageSize int + }{ + { + name: "无参数_使用默认值", + query: "", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "仅指定page", + query: "page=3", + wantPage: 3, + wantPageSize: 20, + }, + { + name: "仅指定page_size", + query: "page_size=50", + wantPage: 1, + wantPageSize: 50, + }, + { + name: "同时指定page和page_size", + query: "page=2&page_size=30", + wantPage: 2, + wantPageSize: 30, + }, + { + name: "使用limit代替page_size", + query: "limit=15", + wantPage: 1, + wantPageSize: 15, + }, + { + name: "page_size优先于limit", + query: "page_size=25&limit=50", + wantPage: 1, + wantPageSize: 25, + }, + { + name: "page为0_使用默认值", + query: "page=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size超过1000_使用默认值", + query: "page_size=1001", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size恰好1000_有效", + query: "page_size=1000", + wantPage: 1, + wantPageSize: 1000, + }, + { + name: "page为非数字_使用默认值", + query: "page=abc", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size为非数字_使用默认值", + query: "page_size=xyz", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit为非数字_使用默认值", + query: "limit=abc", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size为0_使用默认值", + query: "page_size=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit为0_使用默认值", + query: "limit=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "大页码", + query: "page=999&page_size=100", + wantPage: 999, + wantPageSize: 100, + }, + { + name: "page_size为1_最小有效值", + query: "page_size=1", + wantPage: 1, + wantPageSize: 1, + }, + { + name: "混合数字和字母的page", + query: "page=12a", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit超过1000_使用默认值", + query: "limit=2000", + wantPage: 1, + wantPageSize: 20, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, c := newContextWithQuery(tt.query) + + page, pageSize := ParsePagination(c) + + require.Equal(t, tt.wantPage, page, "page 不符合预期") + require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期") + }) + } +} + +func Test_parseInt(t *testing.T) { + tests := []struct { + name string + input string + wantVal int + wantErr bool + }{ + { + name: "正常数字", + input: "123", + wantVal: 123, + wantErr: false, + }, + { + name: "零", + input: "0", + wantVal: 0, + wantErr: false, + }, + { + name: "单个数字", + input: "5", + wantVal: 5, + wantErr: false, + }, + { + name: "大数字", + input: "99999", + wantVal: 99999, + wantErr: false, + }, + { + name: "包含字母_返回0", + input: "abc", + wantVal: 0, + wantErr: false, + }, + { + name: "数字开头接字母_返回0", + input: "12a", + wantVal: 0, + wantErr: false, + }, + { + name: "包含负号_返回0", + input: "-1", + wantVal: 0, + wantErr: false, + }, + { + name: "包含小数点_返回0", + input: "1.5", + wantVal: 0, + wantErr: false, + }, + { + name: "包含空格_返回0", + input: "1 2", + wantVal: 0, + wantErr: false, + }, + { + name: "空字符串", + input: "", + wantVal: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := parseInt(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.wantVal, val) + }) + } +} diff --git a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go index 42510986..4f25a34a 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer.go +++ b/backend/internal/pkg/tlsfingerprint/dialer.go @@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st "cipher_suites", len(spec.CipherSuites), "extensions", len(spec.Extensions), "compression_methods", spec.CompressionMethods, - "tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax), - "tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin)) + "tls_vers_max", spec.TLSVersMax, + "tls_vers_min", spec.TLSVersMin) if d.profile != nil { slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) @@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st return nil, fmt.Errorf("apply TLS preset: %w", err) } - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err) _ = conn.Close() return nil, fmt.Errorf("TLS handshake failed: %w", err) @@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_socks5_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil @@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_http_proxy_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil @@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net. // Log successful handshake details state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil diff --git a/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go index eea74fcc..3f668fbe 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go @@ -30,7 +30,8 @@ func skipIfExternalServiceUnavailable(t *testing.T, err error) { strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "no such host") || strings.Contains(errStr, "network is unreachable") || - strings.Contains(errStr, "timeout") { + strings.Contains(errStr, "timeout") || + strings.Contains(errStr, "deadline exceeded") { t.Skipf("skipping test: external service unavailable: %v", err) } t.Fatalf("failed to get fingerprint: %v", err) diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go index dff7570f..6d3db174 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go @@ -1,3 +1,5 @@ +//go:build unit + // Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. // // Unit tests for TLS fingerprint dialer. @@ -9,26 +11,161 @@ package tlsfingerprint import ( + "context" + "encoding/json" + "io" + "net/http" "net/url" + "os" + "strings" "testing" + "time" ) -// FingerprintResponse represents the response from tls.peet.ws/api/all. -type FingerprintResponse struct { - IP string `json:"ip"` - TLS TLSInfo `json:"tls"` - HTTP2 any `json:"http2"` +// TestDialerBasicConnection tests that the dialer can establish TLS connections. +func TestDialerBasicConnection(t *testing.T) { + skipNetworkTest(t) + + // Create a dialer with default profile + profile := &Profile{ + Name: "Test Profile", + EnableGREASE: false, + } + dialer := NewDialer(profile, nil) + + // Create HTTP client with custom TLS dialer + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + // Make a request to a known HTTPS endpoint + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } } -// TLSInfo contains TLS fingerprint details. -type TLSInfo struct { - JA3 string `json:"ja3"` - JA3Hash string `json:"ja3_hash"` - JA4 string `json:"ja4"` - PeetPrint string `json:"peetprint"` - PeetPrintHash string `json:"peetprint_hash"` - ClientRandom string `json:"client_random"` - SessionID string `json:"session_id"` +// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value. +// This test uses tls.peet.ws to verify the fingerprint. +// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) +// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP) +func TestJA3Fingerprint(t *testing.T) { + skipNetworkTest(t) + + profile := &Profile{ + Name: "Claude CLI Test", + EnableGREASE: false, + } + dialer := NewDialer(profile, nil) + + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + // Use tls.peet.ws fingerprint detection API + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("failed to get fingerprint: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response: %v", err) + } + + var fpResp FingerprintResponse + if err := json.Unmarshal(body, &fpResp); err != nil { + t.Logf("Response body: %s", string(body)) + t.Fatalf("failed to parse fingerprint response: %v", err) + } + + // Log all fingerprint information + t.Logf("JA3: %s", fpResp.TLS.JA3) + t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash) + t.Logf("JA4: %s", fpResp.TLS.JA4) + t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint) + t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash) + + // Verify JA3 hash matches expected value + expectedJA3Hash := "1a28e69016765d92e3b381168d68922c" + if fpResp.TLS.JA3Hash == expectedJA3Hash { + t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash) + } else { + t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash) + } + + // Verify JA4 fingerprint + // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash] + // Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP) + // The suffix _a33745022dd6_1f22a2ca17c4 should match + expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4" + if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) { + t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix) + } else { + t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix) + } + + // Verify JA4 prefix (t13d5911h1 or t13i5911h1) + // d = domain (SNI present), i = IP (no SNI) + // Since we connect to tls.peet.ws (domain), we expect 'd' + expectedJA4Prefix := "t13d5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) { + t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix) + } else { + // Also accept 'i' variant for IP connections + altPrefix := "t13i5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) { + t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix) + } else { + t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix) + } + } + + // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning) + if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") { + t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites") + } else { + t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites") + } + + // Verify extension list (should be 11 extensions including SNI) + // Expected: 0-11-10-35-16-22-23-13-43-45-51 + expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51" + if strings.Contains(fpResp.TLS.JA3, expectedExtensions) { + t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions) + } else { + t.Logf("Warning: JA3 extension list may differ") + } +} + +func skipNetworkTest(t *testing.T) { + if testing.Short() { + t.Skip("跳过网络测试(short 模式)") + } + if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != "1" { + t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1)") + } } // TestDialerWithProfile tests that different profiles produce different fingerprints. @@ -158,3 +295,137 @@ func mustParseURL(rawURL string) *url.URL { } return u } + +// TestProfileExpectation defines expected fingerprint values for a profile. +type TestProfileExpectation struct { + Profile *Profile + ExpectedJA3 string // Expected JA3 hash (empty = don't check) + ExpectedJA4 string // Expected full JA4 (empty = don't check) + JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check) +} + +// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws. +// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/... +func TestAllProfiles(t *testing.T) { + skipNetworkTest(t) + + // Define all profiles to test with their expected fingerprints + // These profiles are from config.yaml gateway.tls_fingerprint.profiles + profiles := []TestProfileExpectation{ + { + // Linux x64 Node.js v22.17.1 + // Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c + // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 + Profile: &Profile{ + Name: "linux_x64_node_v22171", + EnableGREASE: false, + CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, + Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, + PointFormats: []uint8{0, 1, 2}, + }, + JA4CipherHash: "a33745022dd6", // stable part + }, + { + // MacOS arm64 Node.js v22.18.0 + // Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea + // Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406 + Profile: &Profile{ + Name: "macos_arm64_node_v22180", + EnableGREASE: false, + CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, + Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, + PointFormats: []uint8{0, 1, 2}, + }, + JA4CipherHash: "a33745022dd6", // stable part (same cipher suites) + }, + } + + for _, tc := range profiles { + tc := tc // capture range variable + t.Run(tc.Profile.Name, func(t *testing.T) { + fp := fetchFingerprint(t, tc.Profile) + if fp == nil { + return // fetchFingerprint already called t.Fatal + } + + t.Logf("Profile: %s", tc.Profile.Name) + t.Logf(" JA3: %s", fp.JA3) + t.Logf(" JA3 Hash: %s", fp.JA3Hash) + t.Logf(" JA4: %s", fp.JA4) + t.Logf(" PeetPrint: %s", fp.PeetPrint) + t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash) + + // Verify expectations + if tc.ExpectedJA3 != "" { + if fp.JA3Hash == tc.ExpectedJA3 { + t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3) + } else { + t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3) + } + } + + if tc.ExpectedJA4 != "" { + if fp.JA4 == tc.ExpectedJA4 { + t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4) + } else { + t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4) + } + } + + // Check JA4 cipher hash (stable middle part) + // JA4 format: prefix_cipherHash_extHash + if tc.JA4CipherHash != "" { + if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") { + t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash) + } else { + t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash) + } + } + }) + } +} + +// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info. +func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo { + t.Helper() + + dialer := NewDialer(profile, nil) + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + return nil + } + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("failed to get fingerprint: %v", err) + return nil + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response: %v", err) + return nil + } + + var fpResp FingerprintResponse + if err := json.Unmarshal(body, &fpResp); err != nil { + t.Logf("Response body: %s", string(body)) + t.Fatalf("failed to parse fingerprint response: %v", err) + return nil + } + + return &fpResp.TLS +} diff --git a/backend/internal/pkg/tlsfingerprint/test_types_test.go b/backend/internal/pkg/tlsfingerprint/test_types_test.go new file mode 100644 index 00000000..2bbf2d22 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/test_types_test.go @@ -0,0 +1,20 @@ +package tlsfingerprint + +// FingerprintResponse represents the response from tls.peet.ws/api/all. +// 共享测试类型,供 unit 和 integration 测试文件使用。 +type FingerprintResponse struct { + IP string `json:"ip"` + TLS TLSInfo `json:"tls"` + HTTP2 any `json:"http2"` +} + +// TLSInfo contains TLS fingerprint details. +type TLSInfo struct { + JA3 string `json:"ja3"` + JA3Hash string `json:"ja3_hash"` + JA4 string `json:"ja4"` + PeetPrint string `json:"peetprint"` + PeetPrintHash string `json:"peetprint_hash"` + ClientRandom string `json:"client_random"` + SessionID string `json:"session_id"` +} diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 2f6c7fe0..314a6d3c 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -78,6 +78,16 @@ type ModelStat struct { ActualCost float64 `json:"actual_cost"` // 实际扣除 } +// GroupStat represents usage statistics for a single group +type GroupStat struct { + GroupID int64 `json:"group_id"` + GroupName string `json:"group_name"` + Requests int64 `json:"requests"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 +} + // UserUsageTrendPoint represents user usage trend data point type UserUsageTrendPoint struct { Date string `json:"date"` @@ -139,6 +149,7 @@ type UsageLogFilters struct { AccountID int64 GroupID int64 Model string + RequestType *int16 Stream *bool BillingType *int8 StartTime *time.Time diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 11c206d8..4aa74928 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -15,7 +15,6 @@ import ( "database/sql" "encoding/json" "errors" - "log" "strconv" "time" @@ -25,6 +24,7 @@ import ( dbgroup "github.com/Wei-Shaw/sub2api/ent/group" dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" @@ -50,11 +50,6 @@ type accountRepository struct { schedulerCache service.SchedulerCache } -type tempUnschedSnapshot struct { - until *time.Time - reason string -} - // NewAccountRepository 创建账户仓储实例。 // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。 func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository { @@ -127,7 +122,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account account.CreatedAt = created.CreatedAt account.UpdatedAt = created.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { - log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err) } return nil } @@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi accountIDs = append(accountIDs, acc.ID) } - tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs) - if err != nil { - return nil, err - } - groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) if err != nil { return nil, err @@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi if ags, ok := accountGroupsByAccount[entAcc.ID]; ok { out.AccountGroups = ags } - if snap, ok := tempUnschedMap[entAcc.ID]; ok { - out.TempUnschedulableUntil = snap.until - out.TempUnschedulableReason = snap.reason - } outByID[entAcc.ID] = out } @@ -282,6 +268,34 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID return &accounts[0], nil } +func (r *accountRepository) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { + rows, err := r.sql.QueryContext(ctx, ` + SELECT id, extra->>'crs_account_id' + FROM accounts + WHERE deleted_at IS NULL + AND extra->>'crs_account_id' IS NOT NULL + AND extra->>'crs_account_id' != '' + `) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + result := make(map[string]int64) + for rows.Next() { + var id int64 + var crsID string + if err := rows.Scan(&id, &crsID); err != nil { + return nil, err + } + result[crsID] = id + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + func (r *accountRepository) Update(ctx context.Context, account *service.Account) error { if account == nil { return nil @@ -360,7 +374,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account } account.UpdatedAt = updated.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { - log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) } if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable { r.syncSchedulerAccountSnapshot(ctx, account.ID) @@ -401,16 +415,16 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { } } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil { - log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { - return r.ListWithFilters(ctx, params, "", "", "", "") + return r.ListWithFilters(ctx, params, "", "", "", "", 0) } -func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { +func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { q := r.client.Account.Query() if platform != "" { @@ -420,11 +434,19 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati q = q.Where(dbaccount.TypeEQ(accountType)) } if status != "" { - q = q.Where(dbaccount.StatusEQ(status)) + switch status { + case "rate_limited": + q = q.Where(dbaccount.RateLimitResetAtGT(time.Now())) + default: + q = q.Where(dbaccount.StatusEQ(status)) + } } if search != "" { q = q.Where(dbaccount.NameContainsFold(search)) } + if groupID > 0 { + q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID))) + } total, err := q.Count(ctx) if err != nil { @@ -497,7 +519,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error }, } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err) } return nil } @@ -532,7 +554,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map } payload := map[string]any{"last_used": lastUsedPayload} if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue batch last used failed: err=%v", err) } return nil } @@ -547,7 +569,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err) } r.syncSchedulerAccountSnapshot(ctx, id) return nil @@ -567,11 +589,48 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac } account, err := r.GetByID(ctx, accountID) if err != nil { - log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err) + logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err) return } if err := r.schedulerCache.SetAccount(ctx, account); err != nil { - log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err) + logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err) + } +} + +func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) { + if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 { + return + } + + uniqueIDs := make([]int64, 0, len(accountIDs)) + seen := make(map[int64]struct{}, len(accountIDs)) + for _, id := range accountIDs { + if id <= 0 { + continue + } + if _, exists := seen[id]; exists { + continue + } + seen[id] = struct{}{} + uniqueIDs = append(uniqueIDs, id) + } + if len(uniqueIDs) == 0 { + return + } + + accounts, err := r.GetByIDs(ctx, uniqueIDs) + if err != nil { + logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(uniqueIDs), err) + return + } + + for _, account := range accounts { + if account == nil { + continue + } + if err := r.schedulerCache.SetAccount(ctx, account); err != nil { + logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", account.ID, err) + } } } @@ -595,7 +654,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i } payload := buildSchedulerGroupPayload([]int64{groupID}) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err) } return nil } @@ -612,7 +671,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou } payload := buildSchedulerGroupPayload([]int64{groupID}) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err) } return nil } @@ -685,7 +744,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro } payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs)) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err) } return nil } @@ -793,54 +852,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err) - } - return nil -} - -func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { - now := time.Now().UTC() - payload := map[string]string{ - "rate_limited_at": now.Format(time.RFC3339), - "rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339), - } - raw, err := json.Marshal(payload) - if err != nil { - return err - } - - scopeKey := string(scope) - client := clientFromContext(ctx, r.client) - result, err := client.ExecContext( - ctx, - `UPDATE accounts SET - extra = jsonb_set( - jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true), - ARRAY['antigravity_quota_scopes', $1]::text[], - $2::jsonb, - true - ), - updated_at = NOW(), - last_used_at = NOW() - WHERE id = $3 AND deleted_at IS NULL`, - scopeKey, - raw, - id, - ) - if err != nil { - return err - } - - affected, err := result.RowsAffected() - if err != nil { - return err - } - if affected == 0 { - return service.ErrAccountNotFound - } - - if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err) } return nil } @@ -887,7 +899,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err) } return nil } @@ -901,7 +913,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err) } return nil } @@ -920,7 +932,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err) } r.syncSchedulerAccountSnapshot(ctx, id) return nil @@ -939,7 +951,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64 return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err) } return nil } @@ -955,7 +967,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err) } return nil } @@ -979,7 +991,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err) } return nil } @@ -1003,7 +1015,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64) return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err) } return nil } @@ -1025,7 +1037,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s // 触发调度器缓存更新(仅当窗口时间有变化时) if start != nil || end != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err) } } return nil @@ -1040,7 +1052,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err) } if !schedulable { r.syncSchedulerAccountSnapshot(ctx, id) @@ -1068,7 +1080,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti } if rows > 0 { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err) } } return rows, nil @@ -1104,7 +1116,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) } return nil } @@ -1198,7 +1210,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates if rows > 0 { payload := map[string]any{"account_ids": ids} if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bulk update failed: err=%v", err) } shouldSync := false if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) { @@ -1208,9 +1220,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates shouldSync = true } if shouldSync { - for _, id := range ids { - r.syncSchedulerAccountSnapshot(ctx, id) - } + r.syncSchedulerAccountSnapshots(ctx, ids) } } return rows, nil @@ -1302,10 +1312,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d if err != nil { return nil, err } - tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs) - if err != nil { - return nil, err - } groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) if err != nil { return nil, err @@ -1331,10 +1337,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d if ags, ok := accountGroupsByAccount[acc.ID]; ok { out.AccountGroups = ags } - if snap, ok := tempUnschedMap[acc.ID]; ok { - out.TempUnschedulableUntil = snap.until - out.TempUnschedulableReason = snap.reason - } outAccounts = append(outAccounts, *out) } @@ -1359,48 +1361,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account { ) } -func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) { - out := make(map[int64]tempUnschedSnapshot) - if len(accountIDs) == 0 { - return out, nil - } - - rows, err := r.sql.QueryContext(ctx, ` - SELECT id, temp_unschedulable_until, temp_unschedulable_reason - FROM accounts - WHERE id = ANY($1) - `, pq.Array(accountIDs)) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - - for rows.Next() { - var id int64 - var until sql.NullTime - var reason sql.NullString - if err := rows.Scan(&id, &until, &reason); err != nil { - return nil, err - } - var untilPtr *time.Time - if until.Valid { - tmp := until.Time - untilPtr = &tmp - } - if reason.Valid { - out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String} - } else { - out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""} - } - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return out, nil -} - func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) { proxyMap := make(map[int64]*service.Proxy) if len(proxyIDs) == 0 { @@ -1511,31 +1471,33 @@ func accountEntityToService(m *dbent.Account) *service.Account { rateMultiplier := m.RateMultiplier return &service.Account{ - ID: m.ID, - Name: m.Name, - Notes: m.Notes, - Platform: m.Platform, - Type: m.Type, - Credentials: copyJSONMap(m.Credentials), - Extra: copyJSONMap(m.Extra), - ProxyID: m.ProxyID, - Concurrency: m.Concurrency, - Priority: m.Priority, - RateMultiplier: &rateMultiplier, - Status: m.Status, - ErrorMessage: derefString(m.ErrorMessage), - LastUsedAt: m.LastUsedAt, - ExpiresAt: m.ExpiresAt, - AutoPauseOnExpired: m.AutoPauseOnExpired, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - Schedulable: m.Schedulable, - RateLimitedAt: m.RateLimitedAt, - RateLimitResetAt: m.RateLimitResetAt, - OverloadUntil: m.OverloadUntil, - SessionWindowStart: m.SessionWindowStart, - SessionWindowEnd: m.SessionWindowEnd, - SessionWindowStatus: derefString(m.SessionWindowStatus), + ID: m.ID, + Name: m.Name, + Notes: m.Notes, + Platform: m.Platform, + Type: m.Type, + Credentials: copyJSONMap(m.Credentials), + Extra: copyJSONMap(m.Extra), + ProxyID: m.ProxyID, + Concurrency: m.Concurrency, + Priority: m.Priority, + RateMultiplier: &rateMultiplier, + Status: m.Status, + ErrorMessage: derefString(m.ErrorMessage), + LastUsedAt: m.LastUsedAt, + ExpiresAt: m.ExpiresAt, + AutoPauseOnExpired: m.AutoPauseOnExpired, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + Schedulable: m.Schedulable, + RateLimitedAt: m.RateLimitedAt, + RateLimitResetAt: m.RateLimitResetAt, + OverloadUntil: m.OverloadUntil, + TempUnschedulableUntil: m.TempUnschedulableUntil, + TempUnschedulableReason: derefString(m.TempUnschedulableReason), + SessionWindowStart: m.SessionWindowStart, + SessionWindowEnd: m.SessionWindowEnd, + SessionWindowStatus: derefString(m.SessionWindowStatus), } } @@ -1571,3 +1533,64 @@ func joinClauses(clauses []string, sep string) string { func itoa(v int) string { return strconv.Itoa(v) } + +// FindByExtraField 根据 extra 字段中的键值对查找账号。 +// 该方法限定 platform='sora',避免误查询其他平台的账号。 +// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。 +// +// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。 +// +// FindByExtraField finds accounts by key-value pairs in the extra field. +// Limited to platform='sora' to avoid querying accounts from other platforms. +// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index). +// +// Use case: Finding Sora accounts linked via linked_openai_account_id. +func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { + accounts, err := r.client.Account.Query(). + Where( + dbaccount.PlatformEQ("sora"), // 限定平台为 sora + dbaccount.DeletedAtIsNil(), + func(s *entsql.Selector) { + path := sqljson.Path(key) + switch v := value.(type) { + case string: + preds := []*entsql.Predicate{sqljson.ValueEQ(dbaccount.FieldExtra, v, path)} + if parsed, err := strconv.ParseInt(v, 10, 64); err == nil { + preds = append(preds, sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path)) + } + if len(preds) == 1 { + s.Where(preds[0]) + } else { + s.Where(entsql.Or(preds...)) + } + case int: + s.Where(entsql.Or( + sqljson.ValueEQ(dbaccount.FieldExtra, v, path), + sqljson.ValueEQ(dbaccount.FieldExtra, strconv.Itoa(v), path), + )) + case int64: + s.Where(entsql.Or( + sqljson.ValueEQ(dbaccount.FieldExtra, v, path), + sqljson.ValueEQ(dbaccount.FieldExtra, strconv.FormatInt(v, 10), path), + )) + case json.Number: + if parsed, err := v.Int64(); err == nil { + s.Where(entsql.Or( + sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path), + sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path), + )) + } else { + s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path)) + } + default: + s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, value, path)) + } + }, + ). + All(ctx) + if err != nil { + return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil) + } + + return r.accountsToService(ctx, accounts) +} diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index a054b6d6..fd48a5d4 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { tt.setup(client) - accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search) + accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0) s.Require().NoError(err) s.Require().Len(accounts, tt.wantCount) if tt.validate != nil { @@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { s.Require().Len(got.Groups, 1, "expected Groups to be populated") s.Require().Equal(group.ID, got.Groups[0].ID) - accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc") + accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0) s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(accounts, 1) @@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() { s.Require().Nil(got.OverloadUntil) } +func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() { + acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"}) + acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"}) + + until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second) + reason := `{"rule":"429","matched_keyword":"too many requests"}` + s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason)) + + gotByID, err := s.repo.GetByID(s.ctx, acc1.ID) + s.Require().NoError(err) + s.Require().NotNil(gotByID.TempUnschedulableUntil) + s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second) + s.Require().Equal(reason, gotByID.TempUnschedulableReason) + + gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID}) + s.Require().NoError(err) + s.Require().Len(gotByIDs, 2) + s.Require().Equal(acc2.ID, gotByIDs[0].ID) + s.Require().Nil(gotByIDs[0].TempUnschedulableUntil) + s.Require().Equal("", gotByIDs[0].TempUnschedulableReason) + s.Require().Equal(acc1.ID, gotByIDs[1].ID) + s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil) + s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second) + s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason) + + s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID)) + cleared, err := s.repo.GetByID(s.ctx, acc1.ID) + s.Require().NoError(err) + s.Require().Nil(cleared.TempUnschedulableUntil) + s.Require().Equal("", cleared.TempUnschedulableReason) +} + // --- UpdateLastUsed --- func (s *AccountRepoSuite) TestUpdateLastUsed() { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index c0cfd256..b9ce60a5 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -34,6 +34,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro SetName(key.Name). SetStatus(key.Status). SetNillableGroupID(key.GroupID). + SetNillableLastUsedAt(key.LastUsedAt). SetQuota(key.Quota). SetQuotaUsed(key.QuotaUsed). SetNillableExpiresAt(key.ExpiresAt) @@ -48,6 +49,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro created, err := builder.Save(ctx) if err == nil { key.ID = created.ID + key.LastUsedAt = created.LastUsedAt key.CreatedAt = created.CreatedAt key.UpdatedAt = created.UpdatedAt } @@ -140,6 +142,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, + group.FieldSoraImagePrice360, + group.FieldSoraImagePrice540, + group.FieldSoraVideoPricePerRequest, + group.FieldSoraVideoPricePerRequestHd, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, @@ -165,8 +171,9 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro // 则会更新已删除的记录。 // 这里选择 Update().Where(),确保只有未软删除记录能被更新。 // 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。 + client := clientFromContext(ctx, r.client) now := time.Now() - builder := r.client.APIKey.Update(). + builder := client.APIKey.Update(). Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()). SetName(key.Name). SetStatus(key.Status). @@ -375,36 +382,34 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) return keys, nil } -// IncrementQuotaUsed atomically increments the quota_used field and returns the new value +// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值 func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { - // Use raw SQL for atomic increment to avoid race conditions - // First get current value - m, err := r.activeQuery(). - Where(apikey.IDEQ(id)). - Select(apikey.FieldQuotaUsed). - Only(ctx) + updated, err := r.client.APIKey.UpdateOneID(id). + Where(apikey.DeletedAtIsNil()). + AddQuotaUsed(amount). + Save(ctx) if err != nil { if dbent.IsNotFound(err) { return 0, service.ErrAPIKeyNotFound } return 0, err } + return updated.QuotaUsed, nil +} - newValue := m.QuotaUsed + amount - - // Update with new value +func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { affected, err := r.client.APIKey.Update(). Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). - SetQuotaUsed(newValue). + SetLastUsedAt(usedAt). + SetUpdatedAt(usedAt). Save(ctx) if err != nil { - return 0, err + return err } if affected == 0 { - return 0, service.ErrAPIKeyNotFound + return service.ErrAPIKeyNotFound } - - return newValue, nil + return nil } func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { @@ -419,6 +424,7 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { Status: m.Status, IPWhitelist: m.IPWhitelist, IPBlacklist: m.IPBlacklist, + LastUsedAt: m.LastUsedAt, CreatedAt: m.CreatedAt, UpdatedAt: m.UpdatedAt, GroupID: m.GroupID, @@ -440,20 +446,22 @@ func userEntityToService(u *dbent.User) *service.User { return nil } return &service.User{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Notes: u.Notes, - PasswordHash: u.PasswordHash, - Role: u.Role, - Balance: u.Balance, - Concurrency: u.Concurrency, - Status: u.Status, - TotpSecretEncrypted: u.TotpSecretEncrypted, - TotpEnabled: u.TotpEnabled, - TotpEnabledAt: u.TotpEnabledAt, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, + ID: u.ID, + Email: u.Email, + Username: u.Username, + Notes: u.Notes, + PasswordHash: u.PasswordHash, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, + SoraStorageUsedBytes: u.SoraStorageUsedBytes, + TotpSecretEncrypted: u.TotpSecretEncrypted, + TotpEnabled: u.TotpEnabled, + TotpEnabledAt: u.TotpEnabledAt, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, } } @@ -477,6 +485,11 @@ func groupEntityToService(g *dbent.Group) *service.Group { ImagePrice1K: g.ImagePrice1k, ImagePrice2K: g.ImagePrice2k, ImagePrice4K: g.ImagePrice4k, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, + SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, DefaultValidityDays: g.DefaultValidityDays, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, @@ -485,6 +498,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { ModelRoutingEnabled: g.ModelRoutingEnabled, MCPXMLInject: g.McpXMLInject, SupportedModelScopes: g.SupportedModelScopes, + SortOrder: g.SortOrder, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 879a0576..303d7126 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -4,11 +4,14 @@ package repository import ( "context" + "sync" "testing" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group s.Require().NoError(s.repo.Create(s.ctx, k), "create api key") return k } + +// --- IncrementQuotaUsed --- + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() { + user := s.mustCreateUser("incr-basic@test.com") + key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil) + + newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5) + s.Require().NoError(err, "IncrementQuotaUsed") + s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5") + + newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5) + s.Require().NoError(err, "IncrementQuotaUsed second") + s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0") +} + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() { + _, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0) + s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound") +} + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() { + user := s.mustCreateUser("incr-deleted@test.com") + key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil) + + s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete") + + _, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0) + s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound") +} + +// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。 +// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。 +func TestIncrementQuotaUsed_Concurrent(t *testing.T) { + client := testEntClient(t) + repo := NewAPIKeyRepository(client).(*apiKeyRepository) + ctx := context.Background() + + // 创建测试用户和 API Key + u, err := client.User.Create(). + SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com"). + SetPasswordHash("hash"). + SetStatus(service.StatusActive). + SetRole(service.RoleUser). + Save(ctx) + require.NoError(t, err, "create user") + + k := &service.APIKey{ + UserID: u.ID, + Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano), + Name: "Concurrent", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, k), "create api key") + t.Cleanup(func() { + _ = client.APIKey.DeleteOneID(k.ID).Exec(ctx) + _ = client.User.DeleteOneID(u.ID).Exec(ctx) + }) + + // 10 个 goroutine 各递增 1.0,总计应为 10.0 + const goroutines = 10 + const increment = 1.0 + var wg sync.WaitGroup + errs := make([]error, goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment) + }(i) + } + wg.Wait() + + for i, e := range errs { + require.NoError(t, e, "goroutine %d failed", i) + } + + // 验证最终结果 + got, err := repo.GetByID(ctx, k.ID) + require.NoError(t, err, "GetByID") + require.Equal(t, float64(goroutines)*increment, got.QuotaUsed, + "并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed) +} diff --git a/backend/internal/repository/api_key_repo_last_used_unit_test.go b/backend/internal/repository/api_key_repo_last_used_unit_test.go new file mode 100644 index 00000000..7c6e2850 --- /dev/null +++ b/backend/internal/repository/api_key_repo_last_used_unit_test.go @@ -0,0 +1,156 @@ +package repository + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAPIKeyRepoSQLite(t *testing.T) (*apiKeyRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:api_key_repo_last_used?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return &apiKeyRepository{client: client}, client +} + +func mustCreateAPIKeyRepoUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *service.User { + t.Helper() + u, err := client.User.Create(). + SetEmail(email). + SetPasswordHash("test-password-hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + return userEntityToService(u) +} + +func TestAPIKeyRepository_CreateWithLastUsedAt(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "create-last-used@test.com") + + lastUsed := time.Now().UTC().Add(-time.Hour).Truncate(time.Second) + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-create-last-used", + Name: "CreateWithLastUsed", + Status: service.StatusActive, + LastUsedAt: &lastUsed, + } + + require.NoError(t, repo.Create(ctx, key)) + require.NotNil(t, key.LastUsedAt) + require.WithinDuration(t, lastUsed, *key.LastUsedAt, time.Second) + + got, err := repo.GetByID(ctx, key.ID) + require.NoError(t, err) + require.NotNil(t, got.LastUsedAt) + require.WithinDuration(t, lastUsed, *got.LastUsedAt, time.Second) +} + +func TestAPIKeyRepository_UpdateLastUsed(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "update-last-used@test.com") + + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-update-last-used", + Name: "UpdateLastUsed", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, key)) + + before, err := repo.GetByID(ctx, key.ID) + require.NoError(t, err) + require.Nil(t, before.LastUsedAt) + + target := time.Now().UTC().Add(2 * time.Minute).Truncate(time.Second) + require.NoError(t, repo.UpdateLastUsed(ctx, key.ID, target)) + + after, err := repo.GetByID(ctx, key.ID) + require.NoError(t, err) + require.NotNil(t, after.LastUsedAt) + require.WithinDuration(t, target, *after.LastUsedAt, time.Second) + require.WithinDuration(t, target, after.UpdatedAt, time.Second) +} + +func TestAPIKeyRepository_UpdateLastUsedDeletedKey(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "deleted-last-used@test.com") + + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-update-last-used-deleted", + Name: "UpdateLastUsedDeleted", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, key)) + require.NoError(t, repo.Delete(ctx, key.ID)) + + err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC()) + require.ErrorIs(t, err, service.ErrAPIKeyNotFound) +} + +func TestAPIKeyRepository_UpdateLastUsedDBError(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "db-error-last-used@test.com") + + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-update-last-used-db-error", + Name: "UpdateLastUsedDBError", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, key)) + + require.NoError(t, client.Close()) + err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC()) + require.Error(t, err) +} + +func TestAPIKeyRepository_CreateDuplicateKey(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "duplicate-key@test.com") + + first := &service.APIKey{ + UserID: user.ID, + Key: "sk-duplicate", + Name: "first", + Status: service.StatusActive, + } + second := &service.APIKey{ + UserID: user.ID, + Key: "sk-duplicate", + Name: "second", + Status: service.StatusActive, + } + + require.NoError(t, repo.Create(ctx, first)) + err := repo.Create(ctx, second) + require.ErrorIs(t, err, service.ErrAPIKeyExists) +} diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index ac5803a1..e753e1b8 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "math/rand/v2" "strconv" "time" @@ -16,8 +17,19 @@ const ( billingBalanceKeyPrefix = "billing:balance:" billingSubKeyPrefix = "billing:sub:" billingCacheTTL = 5 * time.Minute + billingCacheJitter = 30 * time.Second ) +// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩 +func jitteredTTL() time.Duration { + // 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL(避免上界预期被打破)。 + if billingCacheJitter <= 0 { + return billingCacheTTL + } + jitter := time.Duration(rand.IntN(int(billingCacheJitter))) + return billingCacheTTL - jitter +} + // billingBalanceKey generates the Redis key for user balance cache. func billingBalanceKey(userID int64) string { return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) @@ -82,14 +94,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6 func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error { key := billingBalanceKey(userID) - return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err() + return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err() } func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { key := billingBalanceKey(userID) - _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result() + _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) + return err } return nil } @@ -163,16 +176,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID pipe := c.rdb.Pipeline() pipe.HSet(ctx, key, fields) - pipe.Expire(ctx, key, billingCacheTTL) + pipe.Expire(ctx, key, jitteredTTL()) _, err := pipe.Exec(ctx) return err } func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { key := billingSubKey(userID, groupID) - _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result() + _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err) + return err } return nil } diff --git a/backend/internal/repository/billing_cache_integration_test.go b/backend/internal/repository/billing_cache_integration_test.go index 2f7c69a7..4b7377b1 100644 --- a/backend/internal/repository/billing_cache_integration_test.go +++ b/backend/internal/repository/billing_cache_integration_test.go @@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() { } } +// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复: +// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() { + tests := []struct { + name string + fn func(ctx context.Context, cache service.BillingCache) + expectErr bool + }{ + { + name: "key_not_exists_returns_nil", + fn: func(ctx context.Context, cache service.BillingCache) { + // key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误 + err := cache.DeductUserBalance(ctx, 99999, 1.0) + require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil") + }, + }, + { + name: "existing_key_deducts_successfully", + fn: func(ctx context.Context, cache service.BillingCache) { + require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0)) + err := cache.DeductUserBalance(ctx, 200, 10.0) + require.NoError(s.T(), err, "DeductUserBalance should succeed") + + bal, err := cache.GetUserBalance(ctx, 200) + require.NoError(s.T(), err) + require.Equal(s.T(), 40.0, bal, "余额应为 40.0") + }, + }, + { + name: "cancelled_context_propagates_error", + fn: func(ctx context.Context, cache service.BillingCache) { + require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0)) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() // 立即取消 + + err := cache.DeductUserBalance(cancelCtx, 201, 10.0) + require.Error(s.T(), err, "cancelled context should propagate error") + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + tt.fn(ctx, cache) + }) + } +} + +// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复: +// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() { + s.Run("key_not_exists_returns_nil", func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0) + require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil") + }) + + s.Run("cancelled_context_propagates_error", func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + data := &service.SubscriptionCacheData{ + Status: "active", + ExpiresAt: time.Now().Add(1 * time.Hour), + Version: 1, + } + require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data)) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + + err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0) + require.Error(s.T(), err, "cancelled context should propagate error") + }) +} + func TestBillingCacheSuite(t *testing.T) { suite.Run(t, new(BillingCacheSuite)) } diff --git a/backend/internal/repository/billing_cache_jitter_test.go b/backend/internal/repository/billing_cache_jitter_test.go new file mode 100644 index 00000000..ba4f2873 --- /dev/null +++ b/backend/internal/repository/billing_cache_jitter_test.go @@ -0,0 +1,82 @@ +package repository + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 6.1 验证: math/rand/v2 迁移后 jitteredTTL 行为正确 --- + +func TestJitteredTTL_WithinExpectedRange(t *testing.T) { + // jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter) + // 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内 + lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s + upperBound := billingCacheTTL // 5min + + for i := 0; i < 200; i++ { + ttl := jitteredTTL() + assert.GreaterOrEqual(t, int64(ttl), int64(lowerBound), + "TTL 不应低于 %v,实际得到 %v", lowerBound, ttl) + assert.LessOrEqual(t, int64(ttl), int64(upperBound), + "TTL 不应超过 %v(上界不变保证),实际得到 %v", upperBound, ttl) + } +} + +func TestJitteredTTL_NeverExceedsBase(t *testing.T) { + // 关键安全性测试:jitteredTTL 使用减法抖动,确保永远不超过 billingCacheTTL + for i := 0; i < 500; i++ { + ttl := jitteredTTL() + assert.LessOrEqual(t, int64(ttl), int64(billingCacheTTL), + "jitteredTTL 不应超过基础 TTL(上界预期不被打破)") + } +} + +func TestJitteredTTL_HasVariance(t *testing.T) { + // 验证抖动确实产生了不同的值 + results := make(map[time.Duration]bool) + for i := 0; i < 100; i++ { + ttl := jitteredTTL() + results[ttl] = true + } + + require.Greater(t, len(results), 1, + "jitteredTTL 应产生不同的值(抖动生效),但 100 次调用结果全部相同") +} + +func TestJitteredTTL_AverageNearCenter(t *testing.T) { + // 验证平均值大约在抖动范围中间 + var sum time.Duration + runs := 1000 + for i := 0; i < runs; i++ { + sum += jitteredTTL() + } + + avg := sum / time.Duration(runs) + expectedCenter := billingCacheTTL - billingCacheJitter/2 // 4min45s + + // 允许 ±5s 的误差 + tolerance := 5 * time.Second + assert.InDelta(t, float64(expectedCenter), float64(avg), float64(tolerance), + "平均 TTL 应接近抖动范围中心 %v", expectedCenter) +} + +func TestBillingKeyGeneration(t *testing.T) { + t.Run("balance_key", func(t *testing.T) { + key := billingBalanceKey(12345) + assert.Equal(t, "billing:balance:12345", key) + }) + + t.Run("sub_key", func(t *testing.T) { + key := billingSubKey(100, 200) + assert.Equal(t, "billing:sub:100:200", key) + }) +} + +func BenchmarkJitteredTTL(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = jitteredTTL() + } +} diff --git a/backend/internal/repository/billing_cache_test.go b/backend/internal/repository/billing_cache_test.go index 7d3fd19d..2de1da87 100644 --- a/backend/internal/repository/billing_cache_test.go +++ b/backend/internal/repository/billing_cache_test.go @@ -5,6 +5,7 @@ package repository import ( "math" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) { }) } } + +func TestJitteredTTL(t *testing.T) { + const ( + minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s + maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s + ) + + for i := 0; i < 200; i++ { + ttl := jitteredTTL() + require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl) + require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl) + } +} + +func TestJitteredTTL_HasVariation(t *testing.T) { + // 多次调用应该产生不同的值(验证抖动存在) + seen := make(map[time.Duration]struct{}, 50) + for i := 0; i < 50; i++ { + seen[jitteredTTL()] = struct{}{} + } + // 50 次调用中应该至少有 2 个不同的值 + require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值") +} diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index fc0d2918..b754bd55 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -4,13 +4,14 @@ import ( "context" "encoding/json" "fmt" - "log" "net/http" "net/url" "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/util/logredact" @@ -28,11 +29,14 @@ func NewClaudeOAuthClient() service.ClaudeOAuthClient { type claudeOAuthService struct { baseURL string tokenURL string - clientFactory func(proxyURL string) *req.Client + clientFactory func(proxyURL string) (*req.Client, error) } func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return "", fmt.Errorf("create HTTP client: %w", err) + } var orgs []struct { UUID string `json:"uuid"` @@ -41,7 +45,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey } targetURL := s.baseURL + "/api/organizations" - log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1: Getting organization UUID from %s", targetURL) resp, err := client.R(). SetContext(ctx). @@ -53,11 +57,11 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey Get(targetURL) if err != nil { - log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 FAILED - Request error: %v", err) return "", fmt.Errorf("request failed: %w", err) } - log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 Response - Status: %d", resp.StatusCode) if !resp.IsSuccessState() { return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) @@ -69,26 +73,29 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey // 如果只有一个组织,直接使用 if len(orgs) == 1 { - log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) return orgs[0].UUID, nil } // 如果有多个组织,优先选择 raven_type 为 "team" 的组织 for _, org := range orgs { if org.RavenType != nil && *org.RavenType == "team" { - log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s", + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s", org.UUID, org.Name, *org.RavenType) return org.UUID, nil } } // 如果没有 team 类型的组织,使用第一个 - log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) return orgs[0].UUID, nil } func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return "", fmt.Errorf("create HTTP client: %w", err) + } authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID) @@ -103,9 +110,9 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe "code_challenge_method": "S256", } - log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2: Getting authorization code from %s", authURL) reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody)) - log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) var result struct { RedirectURI string `json:"redirect_uri"` @@ -128,11 +135,11 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe Post(authURL) if err != nil { - log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 FAILED - Request error: %v", err) return "", fmt.Errorf("request failed: %w", err) } - log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) if !resp.IsSuccessState() { return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) @@ -160,12 +167,15 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe fullCode = authCode + "#" + responseState } - log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code") + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 SUCCESS - Got authorization code") return fullCode, nil } func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } // Parse code which may contain state in format "authCode#state" authCode := code @@ -192,9 +202,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds } - log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody)) - log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) var tokenResp oauth.TokenResponse @@ -208,22 +218,25 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod Post(s.tokenURL) if err != nil { - log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 FAILED - Request error: %v", err) return nil, fmt.Errorf("request failed: %w", err) } - log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) if !resp.IsSuccessState() { return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) } - log.Printf("[OAuth] Step 3 SUCCESS - Got access token") + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 SUCCESS - Got access token") return &tokenResp, nil } func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } reqBody := map[string]any{ "grant_type": "refresh_token", @@ -253,16 +266,20 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro return &tokenResp, nil } -func createReqClient(proxyURL string) *req.Client { +func createReqClient(proxyURL string) (*req.Client, error) { // 禁用 CookieJar,确保每次授权都是干净的会话 client := req.C(). SetTimeout(60 * time.Second). ImpersonateChrome(). SetCookieJar(nil) // 禁用 CookieJar - if strings.TrimSpace(proxyURL) != "" { - client.SetProxyURL(strings.TrimSpace(proxyURL)) + trimmed, _, err := proxyurl.Parse(proxyURL) + if err != nil { + return nil, err + } + if trimmed != "" { + client.SetProxyURL(trimmed) } - return client + return client, nil } diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go index 7395c6d8..c6383033 100644 --- a/backend/internal/repository/claude_oauth_service_test.go +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -91,7 +91,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "") @@ -169,7 +169,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "") @@ -276,7 +276,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken) @@ -372,7 +372,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } resp, err := s.client.RefreshToken(context.Background(), "rt", "") diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 1198f472..f6054828 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -83,7 +83,7 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se AllowPrivateHosts: s.allowPrivateHosts, }) if err != nil { - client = &http.Client{Timeout: 30 * time.Second} + return nil, fmt.Errorf("create http client failed: %w", err) } resp, err = client.Do(req) diff --git a/backend/internal/repository/claude_usage_service_test.go b/backend/internal/repository/claude_usage_service_test.go index 2e10f3e5..cbd0b6d3 100644 --- a/backend/internal/repository/claude_usage_service_test.go +++ b/backend/internal/repository/claude_usage_service_test.go @@ -50,7 +50,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { allowPrivateHosts: true, } - resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") + resp, err := s.fetcher.FetchUsage(context.Background(), "at", "") require.NoError(s.T(), err, "FetchUsage") require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch") require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch") @@ -112,6 +112,17 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { require.Error(s.T(), err, "expected error for cancelled context") } +func (s *ClaudeUsageServiceSuite) TestFetchUsage_InvalidProxyReturnsError() { + s.fetcher = &claudeUsageService{ + usageURL: "http://example.com", + allowPrivateHosts: true, + } + + _, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "create http client failed") +} + func TestClaudeUsageServiceSuite(t *testing.T) { suite.Run(t, new(ClaudeUsageServiceSuite)) } diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index cc0c6db5..a2552715 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -147,100 +147,6 @@ var ( return 1 `) - // getAccountsLoadBatchScript - batch load query with expired slot cleanup - // ARGV[1] = slot TTL (seconds) - // ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ... - getAccountsLoadBatchScript = redis.NewScript(` - local result = {} - local slotTTL = tonumber(ARGV[1]) - - -- Get current server time - local timeResult = redis.call('TIME') - local nowSeconds = tonumber(timeResult[1]) - local cutoffTime = nowSeconds - slotTTL - - local i = 2 - while i <= #ARGV do - local accountID = ARGV[i] - local maxConcurrency = tonumber(ARGV[i + 1]) - - local slotKey = 'concurrency:account:' .. accountID - - -- Clean up expired slots before counting - redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime) - local currentConcurrency = redis.call('ZCARD', slotKey) - - local waitKey = 'wait:account:' .. accountID - local waitingCount = redis.call('GET', waitKey) - if waitingCount == false then - waitingCount = 0 - else - waitingCount = tonumber(waitingCount) - end - - local loadRate = 0 - if maxConcurrency > 0 then - loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency) - end - - table.insert(result, accountID) - table.insert(result, currentConcurrency) - table.insert(result, waitingCount) - table.insert(result, loadRate) - - i = i + 2 - end - - return result - `) - - // getUsersLoadBatchScript - batch load query for users with expired slot cleanup - // ARGV[1] = slot TTL (seconds) - // ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ... - getUsersLoadBatchScript = redis.NewScript(` - local result = {} - local slotTTL = tonumber(ARGV[1]) - - -- Get current server time - local timeResult = redis.call('TIME') - local nowSeconds = tonumber(timeResult[1]) - local cutoffTime = nowSeconds - slotTTL - - local i = 2 - while i <= #ARGV do - local userID = ARGV[i] - local maxConcurrency = tonumber(ARGV[i + 1]) - - local slotKey = 'concurrency:user:' .. userID - - -- Clean up expired slots before counting - redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime) - local currentConcurrency = redis.call('ZCARD', slotKey) - - local waitKey = 'concurrency:wait:' .. userID - local waitingCount = redis.call('GET', waitKey) - if waitingCount == false then - waitingCount = 0 - else - waitingCount = tonumber(waitingCount) - end - - local loadRate = 0 - if maxConcurrency > 0 then - loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency) - end - - table.insert(result, userID) - table.insert(result, currentConcurrency) - table.insert(result, waitingCount) - table.insert(result, loadRate) - - i = i + 2 - end - - return result - `) - // cleanupExpiredSlotsScript - remove expired slots // KEYS[1] = concurrency:account:{accountID} // ARGV[1] = TTL (seconds) @@ -321,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID return result, nil } +func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + if len(accountIDs) == 0 { + return map[int64]int{}, nil + } + + now, err := c.rdb.Time(ctx).Result() + if err != nil { + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + type accountCmd struct { + accountID int64 + zcardCmd *redis.IntCmd + } + cmds := make([]accountCmd, 0, len(accountIDs)) + for _, accountID := range accountIDs { + slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) + cmds = append(cmds, accountCmd{ + accountID: accountID, + zcardCmd: pipe.ZCard(ctx, slotKey), + }) + } + + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + result := make(map[int64]int, len(accountIDs)) + for _, cmd := range cmds { + result[cmd.accountID] = int(cmd.zcardCmd.Val()) + } + return result, nil +} + // User slot operations func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { @@ -399,29 +342,53 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts [] return map[int64]*service.AccountLoadInfo{}, nil } - args := []any{c.slotTTLSeconds} - for _, acc := range accounts { - args = append(args, acc.ID, acc.MaxConcurrency) - } - - result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() + // 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster(Lua 内动态拼 key 会 CROSSSLOT)。 + // 每个账号执行 3 个命令:ZREMRANGEBYSCORE(清理过期)、ZCARD(并发数)、GET(等待数)。 + now, err := c.rdb.Time(ctx).Result() if err != nil { - return nil, err + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + + type accountCmds struct { + id int64 + maxConcurrency int + zcardCmd *redis.IntCmd + getCmd *redis.StringCmd + } + cmds := make([]accountCmds, 0, len(accounts)) + for _, acc := range accounts { + slotKey := accountSlotKeyPrefix + strconv.FormatInt(acc.ID, 10) + waitKey := accountWaitKeyPrefix + strconv.FormatInt(acc.ID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) + ac := accountCmds{ + id: acc.ID, + maxConcurrency: acc.MaxConcurrency, + zcardCmd: pipe.ZCard(ctx, slotKey), + getCmd: pipe.Get(ctx, waitKey), + } + cmds = append(cmds, ac) } - loadMap := make(map[int64]*service.AccountLoadInfo) - for i := 0; i < len(result); i += 4 { - if i+3 >= len(result) { - break + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + loadMap := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, ac := range cmds { + currentConcurrency := int(ac.zcardCmd.Val()) + waitingCount := 0 + if v, err := ac.getCmd.Int(); err == nil { + waitingCount = v } - - accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) - currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) - waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) - loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) - - loadMap[accountID] = &service.AccountLoadInfo{ - AccountID: accountID, + loadRate := 0 + if ac.maxConcurrency > 0 { + loadRate = (currentConcurrency + waitingCount) * 100 / ac.maxConcurrency + } + loadMap[ac.id] = &service.AccountLoadInfo{ + AccountID: ac.id, CurrentConcurrency: currentConcurrency, WaitingCount: waitingCount, LoadRate: loadRate, @@ -436,29 +403,52 @@ func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []servic return map[int64]*service.UserLoadInfo{}, nil } - args := []any{c.slotTTLSeconds} - for _, u := range users { - args = append(args, u.ID, u.MaxConcurrency) - } - - result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() + // 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster。 + now, err := c.rdb.Time(ctx).Result() if err != nil { - return nil, err + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + + type userCmds struct { + id int64 + maxConcurrency int + zcardCmd *redis.IntCmd + getCmd *redis.StringCmd + } + cmds := make([]userCmds, 0, len(users)) + for _, u := range users { + slotKey := userSlotKeyPrefix + strconv.FormatInt(u.ID, 10) + waitKey := waitQueueKeyPrefix + strconv.FormatInt(u.ID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) + uc := userCmds{ + id: u.ID, + maxConcurrency: u.MaxConcurrency, + zcardCmd: pipe.ZCard(ctx, slotKey), + getCmd: pipe.Get(ctx, waitKey), + } + cmds = append(cmds, uc) } - loadMap := make(map[int64]*service.UserLoadInfo) - for i := 0; i < len(result); i += 4 { - if i+3 >= len(result) { - break + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + loadMap := make(map[int64]*service.UserLoadInfo, len(users)) + for _, uc := range cmds { + currentConcurrency := int(uc.zcardCmd.Val()) + waitingCount := 0 + if v, err := uc.getCmd.Int(); err == nil { + waitingCount = v } - - userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) - currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) - waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) - loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) - - loadMap[userID] = &service.UserLoadInfo{ - UserID: userID, + loadRate := 0 + if uc.maxConcurrency > 0 { + loadRate = (currentConcurrency + waitingCount) * 100 / uc.maxConcurrency + } + loadMap[uc.id] = &service.UserLoadInfo{ + UserID: uc.id, CurrentConcurrency: currentConcurrency, WaitingCount: waitingCount, LoadRate: loadRate, diff --git a/backend/internal/repository/ent.go b/backend/internal/repository/ent.go index d7d574e8..5f3f5a84 100644 --- a/backend/internal/repository/ent.go +++ b/backend/internal/repository/ent.go @@ -5,6 +5,7 @@ package repository import ( "context" "database/sql" + "fmt" "time" "github.com/Wei-Shaw/sub2api/ent" @@ -66,6 +67,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { // 创建 Ent 客户端,绑定到已配置的数据库驱动。 client := ent.NewClient(ent.Driver(drv)) + // 启动阶段:从配置或数据库中确保系统密钥可用。 + if err := ensureBootstrapSecrets(migrationCtx, client, cfg); err != nil { + _ = client.Close() + return nil, nil, err + } + + // 在密钥补齐后执行完整配置校验,避免空 jwt.secret 导致服务运行时失败。 + if err := cfg.Validate(); err != nil { + _ = client.Close() + return nil, nil, fmt.Errorf("validate config after secret bootstrap: %w", err) + } + // SIMPLE 模式:启动时补齐各平台默认分组。 // - anthropic/openai/gemini: 确保存在 -default // - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景) diff --git a/backend/internal/repository/error_passthrough_repo.go b/backend/internal/repository/error_passthrough_repo.go index a58ab60f..ae989359 100644 --- a/backend/internal/repository/error_passthrough_repo.go +++ b/backend/internal/repository/error_passthrough_repo.go @@ -54,7 +54,8 @@ func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.Err SetPriority(rule.Priority). SetMatchMode(rule.MatchMode). SetPassthroughCode(rule.PassthroughCode). - SetPassthroughBody(rule.PassthroughBody) + SetPassthroughBody(rule.PassthroughBody). + SetSkipMonitoring(rule.SkipMonitoring) if len(rule.ErrorCodes) > 0 { builder.SetErrorCodes(rule.ErrorCodes) @@ -90,7 +91,8 @@ func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.Err SetPriority(rule.Priority). SetMatchMode(rule.MatchMode). SetPassthroughCode(rule.PassthroughCode). - SetPassthroughBody(rule.PassthroughBody) + SetPassthroughBody(rule.PassthroughBody). + SetSkipMonitoring(rule.SkipMonitoring) // 处理可选字段 if len(rule.ErrorCodes) > 0 { @@ -149,6 +151,7 @@ func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model Platforms: e.Platforms, PassthroughCode: e.PassthroughCode, PassthroughBody: e.PassthroughBody, + SkipMonitoring: e.SkipMonitoring, CreatedAt: e.CreatedAt, UpdatedAt: e.UpdatedAt, } diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 9365252a..58291b66 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -11,63 +11,6 @@ import ( const stickySessionPrefix = "sticky_session:" -// Gemini Trie Lua 脚本 -const ( - // geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本 - // KEYS[1] = trie key - // ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d") - // ARGV[2] = TTL seconds (用于刷新) - // 返回: 最长匹配的 value (uuid:accountID) 或 nil - // 查找成功时自动刷新 TTL,防止活跃会话意外过期 - geminiTrieFindScript = ` -local chain = ARGV[1] -local ttl = tonumber(ARGV[2]) -local lastMatch = nil -local path = "" - -for part in string.gmatch(chain, "[^-]+") do - path = path == "" and part or path .. "-" .. part - local val = redis.call('HGET', KEYS[1], path) - if val and val ~= "" then - lastMatch = val - end -end - -if lastMatch then - redis.call('EXPIRE', KEYS[1], ttl) -end - -return lastMatch -` - - // geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本 - // KEYS[1] = trie key - // ARGV[1] = digestChain - // ARGV[2] = value (uuid:accountID) - // ARGV[3] = TTL seconds - geminiTrieSaveScript = ` -local chain = ARGV[1] -local value = ARGV[2] -local ttl = tonumber(ARGV[3]) -local path = "" - -for part in string.gmatch(chain, "[^-]+") do - path = path == "" and part or path .. "-" .. part -end -redis.call('HSET', KEYS[1], path, value) -redis.call('EXPIRE', KEYS[1], ttl) -return "OK" -` -) - -// 模型负载统计相关常量 -const ( - modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀 - modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀 - modelLoadTTL = 24 * time.Hour // 调用次数 TTL(24 小时无调用后清零) - modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL -) - type gatewayCache struct { rdb *redis.Client } @@ -108,133 +51,3 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64 key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() } - -// ============ Antigravity 模型负载统计方法 ============ - -// modelLoadKey 构建模型调用次数 key -// 格式: ag:model_load:{accountID}:{model} -func modelLoadKey(accountID int64, model string) string { - return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model) -} - -// modelLastUsedKey 构建模型最后调度时间 key -// 格式: ag:model_last_used:{accountID}:{model} -func modelLastUsedKey(accountID int64, model string) string { - return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model) -} - -// IncrModelCallCount 增加模型调用次数并更新最后调度时间 -// 返回更新后的调用次数 -func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - loadKey := modelLoadKey(accountID, model) - lastUsedKey := modelLastUsedKey(accountID, model) - - pipe := c.rdb.Pipeline() - incrCmd := pipe.Incr(ctx, loadKey) - pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL - pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL) - if _, err := pipe.Exec(ctx); err != nil { - return 0, err - } - return incrCmd.Val(), nil -} - -// GetModelLoadBatch 批量获取账号的模型负载信息 -func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) { - if len(accountIDs) == 0 { - return make(map[int64]*service.ModelLoadInfo), nil - } - - loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model) - return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil -} - -// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作 -func (c *gatewayCache) pipelineModelLoadGet( - ctx context.Context, - accountIDs []int64, - model string, -) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) { - pipe := c.rdb.Pipeline() - loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs)) - lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs)) - - for _, id := range accountIDs { - loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model)) - lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model)) - } - _, _ = pipe.Exec(ctx) // 忽略错误,key 不存在是正常的 - return loadCmds, lastUsedCmds -} - -// parseModelLoadResults 解析 Pipeline 结果 -func (c *gatewayCache) parseModelLoadResults( - accountIDs []int64, - loadCmds map[int64]*redis.StringCmd, - lastUsedCmds map[int64]*redis.StringCmd, -) map[int64]*service.ModelLoadInfo { - result := make(map[int64]*service.ModelLoadInfo, len(accountIDs)) - for _, id := range accountIDs { - result[id] = &service.ModelLoadInfo{ - CallCount: getInt64OrZero(loadCmds[id]), - LastUsedAt: getTimeOrZero(lastUsedCmds[id]), - } - } - return result -} - -// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0 -func getInt64OrZero(cmd *redis.StringCmd) int64 { - val, _ := cmd.Int64() - return val -} - -// getTimeOrZero 从 StringCmd 获取 time.Time,失败返回零值 -func getTimeOrZero(cmd *redis.StringCmd) time.Time { - val, err := cmd.Int64() - if err != nil { - return time.Time{} - } - return time.Unix(val, 0) -} - -// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============ - -// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询) -// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL -func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" { - return "", 0, false - } - - trieKey := service.BuildGeminiTrieKey(groupID, prefixHash) - ttlSeconds := int(service.GeminiSessionTTL().Seconds()) - - // 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返 - // 查找成功时自动刷新 TTL,防止活跃会话意外过期 - result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result() - if err != nil || result == nil { - return "", 0, false - } - - value, ok := result.(string) - if !ok || value == "" { - return "", 0, false - } - - uuid, accountID, ok = service.ParseGeminiSessionValue(value) - return uuid, accountID, ok -} - -// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本) -func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - if digestChain == "" { - return nil - } - - trieKey := service.BuildGeminiTrieKey(groupID, prefixHash) - value := service.FormatGeminiSessionValue(uuid, accountID) - ttlSeconds := int(service.GeminiSessionTTL().Seconds()) - - return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err() -} diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go index fc8e7372..0eebc33f 100644 --- a/backend/internal/repository/gateway_cache_integration_test.go +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -104,158 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") } -// ============ Gemini Trie 会话测试 ============ - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() { - groupID := int64(1) - prefixHash := "testprefix" - digestChain := "u:hash1-m:hash2-u:hash3" - uuid := "test-uuid-123" - accountID := int64(42) - - // 保存会话 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID) - require.NoError(s.T(), err, "SaveGeminiSession") - - // 精确匹配查找 - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain) - require.True(s.T(), found, "should find exact match") - require.Equal(s.T(), uuid, foundUUID) - require.Equal(s.T(), accountID, foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() { - groupID := int64(1) - prefixHash := "prefixmatch" - shortChain := "u:a-m:b" - longChain := "u:a-m:b-u:c-m:d" - uuid := "uuid-prefix" - accountID := int64(100) - - // 保存短链 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID) - require.NoError(s.T(), err) - - // 用长链查找,应该匹配到短链(前缀匹配) - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain) - require.True(s.T(), found, "should find prefix match") - require.Equal(s.T(), uuid, foundUUID) - require.Equal(s.T(), accountID, foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() { - groupID := int64(1) - prefixHash := "longestmatch" - - // 保存多个不同长度的链 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1) - require.NoError(s.T(), err) - err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2) - require.NoError(s.T(), err) - err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3) - require.NoError(s.T(), err) - - // 查找更长的链,应该匹配到最长的前缀 - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e") - require.True(s.T(), found, "should find longest prefix match") - require.Equal(s.T(), "uuid-long", foundUUID) - require.Equal(s.T(), int64(3), foundAccountID) - - // 查找中等长度的链 - foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x") - require.True(s.T(), found) - require.Equal(s.T(), "uuid-medium", foundUUID) - require.Equal(s.T(), int64(2), foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() { - groupID := int64(1) - prefixHash := "nomatch" - digestChain := "u:a-m:b" - - // 保存一个会话 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1) - require.NoError(s.T(), err) - - // 用不同的链查找,应该找不到 - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y") - require.False(s.T(), found, "should not find non-matching chain") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() { - groupID := int64(1) - digestChain := "u:a-m:b" - - // 保存到 prefixHash1 - err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1) - require.NoError(s.T(), err) - - // 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离) - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain) - require.False(s.T(), found, "different prefixHash should be isolated") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() { - prefixHash := "sameprefix" - digestChain := "u:a-m:b" - - // 保存到 groupID 1 - err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1) - require.NoError(s.T(), err) - - // 用 groupID 2 查找,应该找不到(分组隔离) - _, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain) - require.False(s.T(), found, "different groupID should be isolated") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() { - groupID := int64(1) - prefixHash := "emptytest" - - // 空链不应该保存 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1) - require.NoError(s.T(), err, "empty chain should not error") - - // 空链查找应该返回 false - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "") - require.False(s.T(), found, "empty chain should not match") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() { - groupID := int64(1) - prefixHash := "multisession" - - // 保存多个不同会话(模拟 1000 个并发会话的场景) - sessions := []struct { - chain string - uuid string - accountID int64 - }{ - {"u:session1", "uuid-1", 1}, - {"u:session2-m:reply2", "uuid-2", 2}, - {"u:session3-m:reply3-u:msg3", "uuid-3", 3}, - } - - for _, sess := range sessions { - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID) - require.NoError(s.T(), err) - } - - // 验证每个会话都能正确查找 - for _, sess := range sessions { - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain) - require.True(s.T(), found, "should find session: %s", sess.chain) - require.Equal(s.T(), sess.uuid, foundUUID) - require.Equal(s.T(), sess.accountID, foundAccountID) - } - - // 验证继续对话的场景 - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg") - require.True(s.T(), found) - require.Equal(s.T(), "uuid-2", foundUUID) - require.Equal(s.T(), int64(2), foundAccountID) -} - func TestGatewayCacheSuite(t *testing.T) { suite.Run(t, new(GatewayCacheSuite)) } diff --git a/backend/internal/repository/gateway_cache_model_load_integration_test.go b/backend/internal/repository/gateway_cache_model_load_integration_test.go deleted file mode 100644 index de6fa5ae..00000000 --- a/backend/internal/repository/gateway_cache_model_load_integration_test.go +++ /dev/null @@ -1,234 +0,0 @@ -//go:build integration - -package repository - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -// ============ Gateway Cache 模型负载统计集成测试 ============ - -type GatewayCacheModelLoadSuite struct { - suite.Suite -} - -func TestGatewayCacheModelLoadSuite(t *testing.T) { - suite.Run(t, new(GatewayCacheModelLoadSuite)) -} - -func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - accountID := int64(123) - model := "claude-sonnet-4-20250514" - - // 首次调用应返回 1 - count1, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - require.Equal(t, int64(1), count1) - - // 第二次调用应返回 2 - count2, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - require.Equal(t, int64(2), count2) - - // 第三次调用应返回 3 - count3, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - require.Equal(t, int64(3), count3) -} - -func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - accountID := int64(456) - model1 := "claude-sonnet-4-20250514" - model2 := "claude-opus-4-5-20251101" - - // 不同模型应该独立计数 - count1, err := cache.IncrModelCallCount(ctx, accountID, model1) - require.NoError(t, err) - require.Equal(t, int64(1), count1) - - count2, err := cache.IncrModelCallCount(ctx, accountID, model2) - require.NoError(t, err) - require.Equal(t, int64(1), count2) - - count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1) - require.NoError(t, err) - require.Equal(t, int64(2), count1Again) -} - -func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - account1 := int64(111) - account2 := int64(222) - model := "gemini-2.5-pro" - - // 不同账号应该独立计数 - count1, err := cache.IncrModelCallCount(ctx, account1, model) - require.NoError(t, err) - require.Equal(t, int64(1), count1) - - count2, err := cache.IncrModelCallCount(ctx, account2, model) - require.NoError(t, err) - require.Equal(t, int64(1), count2) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model") - require.NoError(t, err) - require.NotNil(t, result) - require.Empty(t, result) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - // 查询不存在的账号应返回零值 - result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514") - require.NoError(t, err) - require.Len(t, result, 2) - - require.Equal(t, int64(0), result[9999].CallCount) - require.True(t, result[9999].LastUsedAt.IsZero()) - require.Equal(t, int64(0), result[9998].CallCount) - require.True(t, result[9998].LastUsedAt.IsZero()) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - accountID := int64(789) - model := "claude-sonnet-4-20250514" - - // 先增加调用次数 - beforeIncr := time.Now() - _, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - _, err = cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - _, err = cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - afterIncr := time.Now() - - // 获取负载信息 - result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model) - require.NoError(t, err) - require.Len(t, result, 1) - - loadInfo := result[accountID] - require.NotNil(t, loadInfo) - require.Equal(t, int64(3), loadInfo.CallCount) - require.False(t, loadInfo.LastUsedAt.IsZero()) - // LastUsedAt 应该在 beforeIncr 和 afterIncr 之间 - require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr)) - require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr)) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - model := "claude-opus-4-5-20251101" - account1 := int64(1001) - account2 := int64(1002) - account3 := int64(1003) // 不调用 - - // account1 调用 2 次 - _, err := cache.IncrModelCallCount(ctx, account1, model) - require.NoError(t, err) - _, err = cache.IncrModelCallCount(ctx, account1, model) - require.NoError(t, err) - - // account2 调用 5 次 - for i := 0; i < 5; i++ { - _, err = cache.IncrModelCallCount(ctx, account2, model) - require.NoError(t, err) - } - - // 批量获取 - result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model) - require.NoError(t, err) - require.Len(t, result, 3) - - require.Equal(t, int64(2), result[account1].CallCount) - require.False(t, result[account1].LastUsedAt.IsZero()) - - require.Equal(t, int64(5), result[account2].CallCount) - require.False(t, result[account2].LastUsedAt.IsZero()) - - require.Equal(t, int64(0), result[account3].CallCount) - require.True(t, result[account3].LastUsedAt.IsZero()) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - accountID := int64(2001) - model1 := "claude-sonnet-4-20250514" - model2 := "gemini-2.5-pro" - - // 对 model1 调用 3 次 - for i := 0; i < 3; i++ { - _, err := cache.IncrModelCallCount(ctx, accountID, model1) - require.NoError(t, err) - } - - // 获取 model1 的负载 - result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1) - require.NoError(t, err) - require.Equal(t, int64(3), result1[accountID].CallCount) - - // 获取 model2 的负载(应该为 0) - result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2) - require.NoError(t, err) - require.Equal(t, int64(0), result2[accountID].CallCount) -} - -// ============ 辅助函数测试 ============ - -func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() { - t := s.T() - - key := modelLoadKey(123, "claude-sonnet-4") - require.Equal(t, "ag:model_load:123:claude-sonnet-4", key) -} - -func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() { - t := s.T() - - key := modelLastUsedKey(456, "gemini-2.5-pro") - require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key) -} diff --git a/backend/internal/repository/gemini_drive_client.go b/backend/internal/repository/gemini_drive_client.go new file mode 100644 index 00000000..2e383595 --- /dev/null +++ b/backend/internal/repository/gemini_drive_client.go @@ -0,0 +1,9 @@ +package repository + +import "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + +// NewGeminiDriveClient creates a concrete DriveClient for Google Drive API operations. +// Returned as geminicli.DriveClient interface for DI (Strategy A). +func NewGeminiDriveClient() geminicli.DriveClient { + return geminicli.NewDriveClient() +} diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go index 8b7fe625..eb14f313 100644 --- a/backend/internal/repository/gemini_oauth_client.go +++ b/backend/internal/repository/gemini_oauth_client.go @@ -26,7 +26,10 @@ func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient { } func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) { - client := createGeminiReqClient(proxyURL) + client, err := createGeminiReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } // Use different OAuth clients based on oauthType: // - code_assist: always use built-in Gemini CLI OAuth client (public) @@ -72,7 +75,10 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c } func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { - client := createGeminiReqClient(proxyURL) + client, err := createGeminiReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } oauthCfgInput := geminicli.OAuthConfig{ ClientID: c.cfg.Gemini.OAuth.ClientID, @@ -111,7 +117,7 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh return &tokenResp, nil } -func createGeminiReqClient(proxyURL string) *req.Client { +func createGeminiReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 60 * time.Second, diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index 4f63280d..b5bc6497 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -26,7 +26,11 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo } var out geminicli.LoadCodeAssistResponse - resp, err := createGeminiCliReqClient(proxyURL).R(). + client, err := createGeminiCliReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } + resp, err := client.R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). @@ -66,7 +70,11 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody) var out geminicli.OnboardUserResponse - resp, err := createGeminiCliReqClient(proxyURL).R(). + client, err := createGeminiCliReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } + resp, err := client.R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). @@ -98,7 +106,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken return &out, nil } -func createGeminiCliReqClient(proxyURL string) *req.Client { +func createGeminiCliReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 30 * time.Second, diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index 03f8cc66..ad1f22e3 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -5,8 +5,10 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net/http" "os" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" @@ -18,14 +20,27 @@ type githubReleaseClient struct { downloadHTTPClient *http.Client } +type githubReleaseClientError struct { + err error +} + // NewGitHubReleaseClient 创建 GitHub Release 客户端 // proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议 -func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient { +// 代理配置失败时行为由 allowDirectOnProxyError 控制: +// - false(默认):返回错误占位客户端,禁止回退到直连 +// - true:回退到直连(仅限管理员显式开启) +func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient { + // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据, + // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy client init failed, all requests will fail", "service", "github_release", "error", err) + return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} + } sharedClient = &http.Client{Timeout: 30 * time.Second} } @@ -35,6 +50,10 @@ func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient { ProxyURL: proxyURL, }) if err != nil { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy download client init failed, all requests will fail", "service", "github_release", "error", err) + return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} + } downloadClient = &http.Client{Timeout: 10 * time.Minute} } @@ -44,6 +63,18 @@ func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient { } } +func (c *githubReleaseClientError) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) { + return nil, c.err +} + +func (c *githubReleaseClientError) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error { + return c.err +} + +func (c *githubReleaseClientError) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) { + return nil, c.err +} + func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo) diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index d8cec491..4edc8534 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -4,11 +4,13 @@ import ( "context" "database/sql" "errors" - "log" + "fmt" + "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" @@ -47,12 +49,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). + SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). + SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). + SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). + SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). - SetMcpXMLInject(groupIn.MCPXMLInject) + SetMcpXMLInject(groupIn.MCPXMLInject). + SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -68,7 +75,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er groupIn.CreatedAt = created.CreatedAt groupIn.UpdatedAt = created.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err) } } return translatePersistenceError(err, nil, service.ErrGroupExists) @@ -110,10 +117,47 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). + SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). + SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). + SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). + SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). - SetMcpXMLInject(groupIn.MCPXMLInject) + SetMcpXMLInject(groupIn.MCPXMLInject). + SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) + + // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 + if groupIn.DailyLimitUSD != nil { + builder = builder.SetDailyLimitUsd(*groupIn.DailyLimitUSD) + } else { + builder = builder.ClearDailyLimitUsd() + } + if groupIn.WeeklyLimitUSD != nil { + builder = builder.SetWeeklyLimitUsd(*groupIn.WeeklyLimitUSD) + } else { + builder = builder.ClearWeeklyLimitUsd() + } + if groupIn.MonthlyLimitUSD != nil { + builder = builder.SetMonthlyLimitUsd(*groupIn.MonthlyLimitUSD) + } else { + builder = builder.ClearMonthlyLimitUsd() + } + if groupIn.ImagePrice1K != nil { + builder = builder.SetImagePrice1k(*groupIn.ImagePrice1K) + } else { + builder = builder.ClearImagePrice1k() + } + if groupIn.ImagePrice2K != nil { + builder = builder.SetImagePrice2k(*groupIn.ImagePrice2K) + } else { + builder = builder.ClearImagePrice2k() + } + if groupIn.ImagePrice4K != nil { + builder = builder.SetImagePrice4k(*groupIn.ImagePrice4K) + } else { + builder = builder.ClearImagePrice4k() + } // 处理 FallbackGroupID:nil 时清除,否则设置 if groupIn.FallbackGroupID != nil { @@ -144,7 +188,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er } groupIn.UpdatedAt = updated.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err) } return nil } @@ -155,7 +199,7 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error { return translatePersistenceError(err, service.ErrGroupNotFound, nil) } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err) } return nil } @@ -183,7 +227,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination q = q.Where(group.IsExclusiveEQ(*isExclusive)) } - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } @@ -191,7 +235,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination groups, err := q. Offset(params.Offset()). Limit(params.Limit()). - Order(dbent.Asc(group.FieldID)). + Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)). All(ctx) if err != nil { return nil, nil, err @@ -218,7 +262,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) { groups, err := r.client.Group.Query(). Where(group.StatusEQ(service.StatusActive)). - Order(dbent.Asc(group.FieldID)). + Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)). All(ctx) if err != nil { return nil, err @@ -245,7 +289,7 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) { groups, err := r.client.Group.Query(). Where(group.StatusEQ(service.StatusActive), group.PlatformEQ(platform)). - Order(dbent.Asc(group.FieldID)). + Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)). All(ctx) if err != nil { return nil, err @@ -273,6 +317,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx) } +// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。 +// 返回结构:map[groupID]exists。 +func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) { + result := make(map[int64]bool, len(ids)) + if len(ids) == 0 { + return result, nil + } + + uniqueIDs := make([]int64, 0, len(ids)) + seen := make(map[int64]struct{}, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + uniqueIDs = append(uniqueIDs, id) + result[id] = false + } + if len(uniqueIDs) == 0 { + return result, nil + } + + rows, err := r.sql.QueryContext(ctx, ` + SELECT id + FROM groups + WHERE id = ANY($1) AND deleted_at IS NULL + `, pq.Array(uniqueIDs)) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, err + } + result[id] = true + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { var count int64 if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil { @@ -288,7 +380,7 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou } affected, _ := res.RowsAffected() if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err) } return affected, nil } @@ -398,7 +490,7 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, } } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err) } return affectedUserIDs, nil @@ -492,8 +584,84 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64 // 发送调度器事件 if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err) } return nil } + +// UpdateSortOrders 批量更新分组排序 +func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error { + if len(updates) == 0 { + return nil + } + + // 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。 + sortOrderByID := make(map[int64]int, len(updates)) + groupIDs := make([]int64, 0, len(updates)) + for _, u := range updates { + if u.ID <= 0 { + continue + } + if _, exists := sortOrderByID[u.ID]; !exists { + groupIDs = append(groupIDs, u.ID) + } + sortOrderByID[u.ID] = u.SortOrder + } + if len(groupIDs) == 0 { + return nil + } + + // 与旧实现保持一致:任何不存在/已删除的分组都返回 not found,且不执行更新。 + var existingCount int + if err := scanSingleRow( + ctx, + r.sql, + `SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`, + []any{pq.Array(groupIDs)}, + &existingCount, + ); err != nil { + return err + } + if existingCount != len(groupIDs) { + return service.ErrGroupNotFound + } + + args := make([]any, 0, len(groupIDs)*2+1) + caseClauses := make([]string, 0, len(groupIDs)) + placeholder := 1 + for _, id := range groupIDs { + caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1)) + args = append(args, id, sortOrderByID[id]) + placeholder += 2 + } + args = append(args, pq.Array(groupIDs)) + + query := fmt.Sprintf(` + UPDATE groups + SET sort_order = CASE id + %s + ELSE sort_order + END + WHERE deleted_at IS NULL AND id = ANY($%d) + `, strings.Join(caseClauses, "\n\t\t\t"), placeholder) + + result, err := r.sql.ExecContext(ctx, query, args...) + if err != nil { + return err + } + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected != int64(len(groupIDs)) { + return service.ErrGroupNotFound + } + + for _, id := range groupIDs { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err) + } + } + return nil +} diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index c31a9ec4..4a849a46 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -352,6 +352,81 @@ func (s *GroupRepoSuite) TestListWithFilters_Search() { }) } +func (s *GroupRepoSuite) TestUpdateSortOrders_BatchCaseWhen() { + g1 := &service.Group{ + Name: "sort-g1", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + g2 := &service.Group{ + Name: "sort-g2", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + g3 := &service.Group{ + Name: "sort-g3", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g1)) + s.Require().NoError(s.repo.Create(s.ctx, g2)) + s.Require().NoError(s.repo.Create(s.ctx, g3)) + + err := s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{ + {ID: g1.ID, SortOrder: 30}, + {ID: g2.ID, SortOrder: 10}, + {ID: g3.ID, SortOrder: 20}, + {ID: g2.ID, SortOrder: 15}, // 重复 ID 应以最后一次为准 + }) + s.Require().NoError(err) + + got1, err := s.repo.GetByID(s.ctx, g1.ID) + s.Require().NoError(err) + got2, err := s.repo.GetByID(s.ctx, g2.ID) + s.Require().NoError(err) + got3, err := s.repo.GetByID(s.ctx, g3.ID) + s.Require().NoError(err) + s.Require().Equal(30, got1.SortOrder) + s.Require().Equal(15, got2.SortOrder) + s.Require().Equal(20, got3.SortOrder) +} + +func (s *GroupRepoSuite) TestUpdateSortOrders_MissingGroupNoPartialUpdate() { + g1 := &service.Group{ + Name: "sort-no-partial", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g1)) + + before, err := s.repo.GetByID(s.ctx, g1.ID) + s.Require().NoError(err) + beforeSort := before.SortOrder + + err = s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{ + {ID: g1.ID, SortOrder: 99}, + {ID: 99999999, SortOrder: 1}, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, service.ErrGroupNotFound) + + after, err := s.repo.GetByID(s.ctx, g1.ID) + s.Require().NoError(err) + s.Require().Equal(beforeSort, after.SortOrder) +} + func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { g1 := &service.Group{ Name: "g1", diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index b0f15f19..a4674c1a 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/Wei-Shaw/sub2api/internal/service" @@ -235,7 +236,10 @@ func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID in // TLS 指纹客户端使用独立的缓存键,与普通客户端隔离 func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { isolation := s.getIsolationMode() - proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) + if err != nil { + return nil, err + } // TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀 cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID) poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls" @@ -373,9 +377,8 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac // - proxy: 按代理地址隔离,同一代理共享客户端 // - account: 按账户隔离,同一账户共享客户端(代理变更时重建) // - account_proxy: 按账户+代理组合隔离,最细粒度 -func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry { - entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) - return entry +func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) } // getClientEntry 获取或创建客户端条目 @@ -385,7 +388,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a // 获取隔离模式 isolation := s.getIsolationMode() // 标准化代理 URL 并解析 - proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) + if err != nil { + return nil, err + } // 构建缓存键(根据隔离策略不同) cacheKey := buildCacheKey(isolation, proxyKey, accountID) // 构建连接池配置键(用于检测配置变更) @@ -680,17 +686,18 @@ func buildCacheKey(isolation, proxyKey string, accountID int64) string { // - raw: 原始代理 URL 字符串 // // 返回: -// - string: 标准化的代理键(空或解析失败返回 "direct") -// - *url.URL: 解析后的 URL(空或解析失败返回 nil) -func normalizeProxyURL(raw string) (string, *url.URL) { - proxyURL := strings.TrimSpace(raw) - if proxyURL == "" { - return directProxyKey, nil - } - parsed, err := url.Parse(proxyURL) +// - string: 标准化的代理键(空返回 "direct") +// - *url.URL: 解析后的 URL(空返回 nil) +// - error: 非空代理 URL 解析失败时返回错误(禁止回退到直连) +func normalizeProxyURL(raw string) (string, *url.URL, error) { + _, parsed, err := proxyurl.Parse(raw) if err != nil { - return directProxyKey, nil + return "", nil, err } + if parsed == nil { + return directProxyKey, nil, nil + } + // 规范化:小写 scheme/host,去除路径和查询参数 parsed.Scheme = strings.ToLower(parsed.Scheme) parsed.Host = strings.ToLower(parsed.Host) parsed.Path = "" @@ -710,7 +717,7 @@ func normalizeProxyURL(raw string) (string, *url.URL) { parsed.Host = hostname } } - return parsed.String(), parsed + return parsed.String(), parsed, nil } // defaultPoolSettings 获取默认连接池配置 diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go index 1e7430a3..89892b3b 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -59,7 +59,10 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { // 模拟优化后的行为,从缓存获取客户端 b.Run("复用", func(b *testing.B) { // 预热:确保客户端已缓存 - entry := svc.getOrCreateClient(proxyURL, 1, 1) + entry, err := svc.getOrCreateClient(proxyURL, 1, 1) + if err != nil { + b.Fatalf("getOrCreateClient: %v", err) + } client := entry.client b.ResetTimer() // 重置计时器,排除预热时间 for i := 0; i < b.N; i++ { diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index fbe44c5e..b3268463 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -44,7 +44,7 @@ func (s *HTTPUpstreamSuite) newService() *httpUpstreamService { // 验证未配置时使用 300 秒默认值 func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { svc := s.newService() - entry := svc.getOrCreateClient("", 0, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") @@ -55,25 +55,27 @@ func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7} svc := s.newService() - entry := svc.getOrCreateClient("", 0, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") } -// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退 -// 验证解析失败时回退到直连模式 -func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() { +// TestGetOrCreateClient_InvalidURLReturnsError 测试无效代理 URL 返回错误 +// 验证解析失败时拒绝回退到直连模式 +func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLReturnsError() { svc := s.newService() - entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1) - require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback") + _, err := svc.getClientEntry("://bad-proxy-url", 1, 1, false, false) + require.Error(s.T(), err, "expected error for invalid proxy URL") } // TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化 // 验证等价地址能够映射到同一缓存键 func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() { - key1, _ := normalizeProxyURL("http://proxy.local:8080") - key2, _ := normalizeProxyURL("http://proxy.local:8080/") + key1, _, err1 := normalizeProxyURL("http://proxy.local:8080") + require.NoError(s.T(), err1) + key2, _, err2 := normalizeProxyURL("http://proxy.local:8080/") + require.NoError(s.T(), err2) require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match") } @@ -171,8 +173,8 @@ func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一代理,不同账户 - entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 2, 3) require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池") require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端") } @@ -183,8 +185,8 @@ func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy} svc := s.newService() // 同一账户,不同代理 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理") require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端") } @@ -195,8 +197,8 @@ func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一账户,先后使用不同代理 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池") require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池") require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理") @@ -208,7 +210,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 账户并发数为 12 - entry := svc.getOrCreateClient("", 1, 12) + entry := mustGetOrCreateClient(s.T(), svc, "", 1, 12) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") // 连接池参数应与并发数一致 @@ -228,7 +230,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() { } svc := s.newService() // 账户并发数为 0,应使用全局配置 - entry := svc.getOrCreateClient("", 1, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 1, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch") @@ -245,12 +247,12 @@ func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() { } svc := s.newService() // 创建两个客户端,设置不同的最后使用时间 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 1) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 2, 1) atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久 atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano()) // 创建第三个客户端,触发淘汰 - _ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1) + _ = mustGetOrCreateClient(s.T(), svc, "http://proxy-c:8080", 3, 1) require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内") require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理") @@ -264,12 +266,12 @@ func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() { ClientIdleTTLSeconds: 1, // 1 秒空闲超时 } svc := s.newService() - entry1 := svc.getOrCreateClient("", 1, 1) + entry1 := mustGetOrCreateClient(s.T(), svc, "", 1, 1) // 设置为很久之前使用,但有活跃请求 atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano()) atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求 // 创建新客户端,触发淘汰检查 - _ = svc.getOrCreateClient("", 2, 1) + _, _ = svc.getOrCreateClient("", 2, 1) require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收") } @@ -279,6 +281,14 @@ func TestHTTPUpstreamSuite(t *testing.T) { suite.Run(t, new(HTTPUpstreamSuite)) } +// mustGetOrCreateClient 测试辅助函数,调用 getOrCreateClient 并断言无错误 +func mustGetOrCreateClient(t *testing.T, svc *httpUpstreamService, proxyURL string, accountID int64, concurrency int) *upstreamClientEntry { + t.Helper() + entry, err := svc.getOrCreateClient(proxyURL, accountID, concurrency) + require.NoError(t, err, "getOrCreateClient(%q, %d, %d)", proxyURL, accountID, concurrency) + return entry +} + // hasEntry 检查客户端是否存在于缓存中 // 辅助函数,用于验证淘汰逻辑 func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool { diff --git a/backend/internal/repository/idempotency_repo.go b/backend/internal/repository/idempotency_repo.go new file mode 100644 index 00000000..32f2faae --- /dev/null +++ b/backend/internal/repository/idempotency_repo.go @@ -0,0 +1,237 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type idempotencyRepository struct { + sql sqlExecutor +} + +func NewIdempotencyRepository(_ *dbent.Client, sqlDB *sql.DB) service.IdempotencyRepository { + return &idempotencyRepository{sql: sqlDB} +} + +func (r *idempotencyRepository) CreateProcessing(ctx context.Context, record *service.IdempotencyRecord) (bool, error) { + if record == nil { + return false, nil + } + query := ` + INSERT INTO idempotency_records ( + scope, idempotency_key_hash, request_fingerprint, status, locked_until, expires_at + ) VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (scope, idempotency_key_hash) DO NOTHING + RETURNING id, created_at, updated_at + ` + var createdAt time.Time + var updatedAt time.Time + err := scanSingleRow(ctx, r.sql, query, []any{ + record.Scope, + record.IdempotencyKeyHash, + record.RequestFingerprint, + record.Status, + record.LockedUntil, + record.ExpiresAt, + }, &record.ID, &createdAt, &updatedAt) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, err + } + record.CreatedAt = createdAt + record.UpdatedAt = updatedAt + return true, nil +} + +func (r *idempotencyRepository) GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + query := ` + SELECT + id, scope, idempotency_key_hash, request_fingerprint, status, response_status, + response_body, error_reason, locked_until, expires_at, created_at, updated_at + FROM idempotency_records + WHERE scope = $1 AND idempotency_key_hash = $2 + ` + record := &service.IdempotencyRecord{} + var responseStatus sql.NullInt64 + var responseBody sql.NullString + var errorReason sql.NullString + var lockedUntil sql.NullTime + err := scanSingleRow(ctx, r.sql, query, []any{scope, keyHash}, + &record.ID, + &record.Scope, + &record.IdempotencyKeyHash, + &record.RequestFingerprint, + &record.Status, + &responseStatus, + &responseBody, + &errorReason, + &lockedUntil, + &record.ExpiresAt, + &record.CreatedAt, + &record.UpdatedAt, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, err + } + if responseStatus.Valid { + v := int(responseStatus.Int64) + record.ResponseStatus = &v + } + if responseBody.Valid { + v := responseBody.String + record.ResponseBody = &v + } + if errorReason.Valid { + v := errorReason.String + record.ErrorReason = &v + } + if lockedUntil.Valid { + v := lockedUntil.Time + record.LockedUntil = &v + } + return record, nil +} + +func (r *idempotencyRepository) TryReclaim( + ctx context.Context, + id int64, + fromStatus string, + now, newLockedUntil, newExpiresAt time.Time, +) (bool, error) { + query := ` + UPDATE idempotency_records + SET status = $2, + locked_until = $3, + error_reason = NULL, + updated_at = NOW(), + expires_at = $4 + WHERE id = $1 + AND status = $5 + AND (locked_until IS NULL OR locked_until <= $6) + ` + res, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusProcessing, + newLockedUntil, + newExpiresAt, + fromStatus, + now, + ) + if err != nil { + return false, err + } + affected, err := res.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *idempotencyRepository) ExtendProcessingLock( + ctx context.Context, + id int64, + requestFingerprint string, + newLockedUntil, + newExpiresAt time.Time, +) (bool, error) { + query := ` + UPDATE idempotency_records + SET locked_until = $2, + expires_at = $3, + updated_at = NOW() + WHERE id = $1 + AND status = $4 + AND request_fingerprint = $5 + ` + res, err := r.sql.ExecContext( + ctx, + query, + id, + newLockedUntil, + newExpiresAt, + service.IdempotencyStatusProcessing, + requestFingerprint, + ) + if err != nil { + return false, err + } + affected, err := res.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *idempotencyRepository) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + query := ` + UPDATE idempotency_records + SET status = $2, + response_status = $3, + response_body = $4, + error_reason = NULL, + locked_until = NULL, + expires_at = $5, + updated_at = NOW() + WHERE id = $1 + ` + _, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusSucceeded, + responseStatus, + responseBody, + expiresAt, + ) + return err +} + +func (r *idempotencyRepository) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + query := ` + UPDATE idempotency_records + SET status = $2, + error_reason = $3, + locked_until = $4, + expires_at = $5, + updated_at = NOW() + WHERE id = $1 + ` + _, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusFailedRetryable, + errorReason, + lockedUntil, + expiresAt, + ) + return err +} + +func (r *idempotencyRepository) DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) { + if limit <= 0 { + limit = 500 + } + query := ` + WITH victims AS ( + SELECT id + FROM idempotency_records + WHERE expires_at <= $1 + ORDER BY expires_at ASC + LIMIT $2 + ) + DELETE FROM idempotency_records + WHERE id IN (SELECT id FROM victims) + ` + res, err := r.sql.ExecContext(ctx, query, now, limit) + if err != nil { + return 0, err + } + return res.RowsAffected() +} diff --git a/backend/internal/repository/idempotency_repo_integration_test.go b/backend/internal/repository/idempotency_repo_integration_test.go new file mode 100644 index 00000000..f163c2f0 --- /dev/null +++ b/backend/internal/repository/idempotency_repo_integration_test.go @@ -0,0 +1,149 @@ +//go:build integration + +package repository + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +// hashedTestValue returns a unique SHA-256 hex string (64 chars) that fits VARCHAR(64) columns. +func hashedTestValue(t *testing.T, prefix string) string { + t.Helper() + sum := sha256.Sum256([]byte(uniqueTestValue(t, prefix))) + return hex.EncodeToString(sum[:]) +} + +func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-create"), + IdempotencyKeyHash: hashedTestValue(t, "idem-hash"), + RequestFingerprint: hashedTestValue(t, "idem-fp"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(30 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + require.NotZero(t, record.ID) + + duplicate := &service.IdempotencyRecord{ + Scope: record.Scope, + IdempotencyKeyHash: record.IdempotencyKeyHash, + RequestFingerprint: hashedTestValue(t, "idem-fp-other"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(30 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err = repo.CreateProcessing(ctx, duplicate) + require.NoError(t, err) + require.False(t, owner, "same scope+key hash should be de-duplicated") +} + +func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-reclaim"), + IdempotencyKeyHash: hashedTestValue(t, "idem-hash-reclaim"), + RequestFingerprint: hashedTestValue(t, "idem-fp-reclaim"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(10 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + + require.NoError(t, repo.MarkFailedRetryable( + ctx, + record.ID, + "RETRYABLE_FAILURE", + now.Add(-2*time.Second), + now.Add(24*time.Hour), + )) + + newLockedUntil := now.Add(20 * time.Second) + reclaimed, err := repo.TryReclaim( + ctx, + record.ID, + service.IdempotencyStatusFailedRetryable, + now, + newLockedUntil, + now.Add(24*time.Hour), + ) + require.NoError(t, err) + require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim") + + got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, service.IdempotencyStatusProcessing, got.Status) + require.NotNil(t, got.LockedUntil) + require.True(t, got.LockedUntil.After(now)) + + require.NoError(t, repo.MarkFailedRetryable( + ctx, + record.ID, + "RETRYABLE_FAILURE", + now.Add(20*time.Second), + now.Add(24*time.Hour), + )) + + reclaimed, err = repo.TryReclaim( + ctx, + record.ID, + service.IdempotencyStatusFailedRetryable, + now, + now.Add(40*time.Second), + now.Add(24*time.Hour), + ) + require.NoError(t, err) + require.False(t, reclaimed, "within lock window should not reclaim") +} + +func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-success"), + IdempotencyKeyHash: hashedTestValue(t, "idem-hash-success"), + RequestFingerprint: hashedTestValue(t, "idem-fp-success"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(10 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + + require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour))) + + got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, service.IdempotencyStatusSucceeded, got.Status) + require.NotNil(t, got.ResponseStatus) + require.Equal(t, 200, *got.ResponseStatus) + require.NotNil(t, got.ResponseBody) + require.Equal(t, `{"ok":true}`, *got.ResponseBody) + require.Nil(t, got.LockedUntil) +} diff --git a/backend/internal/repository/identity_cache.go b/backend/internal/repository/identity_cache.go index c4986547..6152dd7a 100644 --- a/backend/internal/repository/identity_cache.go +++ b/backend/internal/repository/identity_cache.go @@ -12,7 +12,7 @@ import ( const ( fingerprintKeyPrefix = "fingerprint:" - fingerprintTTL = 24 * time.Hour + fingerprintTTL = 7 * 24 * time.Hour // 7天,配合每24小时懒续期可保持活跃账号永不过期 maskedSessionKeyPrefix = "masked_session:" maskedSessionTTL = 15 * time.Minute ) diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index 5912e50f..a60ba294 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -50,6 +50,23 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions ( // 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。 const migrationsAdvisoryLockID int64 = 694208311321144027 const migrationsLockRetryInterval = 500 * time.Millisecond +const nonTransactionalMigrationSuffix = "_notx.sql" + +type migrationChecksumCompatibilityRule struct { + fileChecksum string + acceptedDBChecksum map[string]struct{} +} + +// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。 +// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。 +var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{ + "054_drop_legacy_cache_columns.sql": { + fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + acceptedDBChecksum: map[string]struct{}{ + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {}, + }, + }, +} // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 // @@ -147,6 +164,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { if rowErr == nil { // 迁移已应用,验证校验和是否匹配 if existing != checksum { + // 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。 + if isMigrationChecksumCompatible(name, existing, checksum) { + continue + } // 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。 // 正确的做法是创建新的迁移文件来进行变更。 return fmt.Errorf( @@ -165,8 +186,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { return fmt.Errorf("check migration %s: %w", name, rowErr) } - // 迁移未应用,在事务中执行。 - // 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。 + nonTx, err := validateMigrationExecutionMode(name, content) + if err != nil { + return fmt.Errorf("validate migration %s: %w", name, err) + } + + if nonTx { + // *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。 + // 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。 + statements := splitSQLStatements(content) + for i, stmt := range statements { + trimmed := strings.TrimSpace(stmt) + if trimmed == "" { + continue + } + if stripSQLLineComment(trimmed) == "" { + continue + } + if _, err := db.ExecContext(ctx, trimmed); err != nil { + return fmt.Errorf("apply migration %s (non-tx statement %d): %w", name, i+1, err) + } + } + if _, err := db.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil { + return fmt.Errorf("record migration %s (non-tx): %w", name, err) + } + continue + } + + // 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。 tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin migration %s: %w", name, err) @@ -268,6 +315,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) { return version, version, hash, nil } +func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool { + rule, ok := migrationChecksumCompatibilityRules[name] + if !ok { + return false + } + if rule.fileChecksum != fileChecksum { + return false + } + _, ok = rule.acceptedDBChecksum[dbChecksum] + return ok +} + +func validateMigrationExecutionMode(name, content string) (bool, error) { + normalizedName := strings.ToLower(strings.TrimSpace(name)) + upperContent := strings.ToUpper(content) + nonTx := strings.HasSuffix(normalizedName, nonTransactionalMigrationSuffix) + + if !nonTx { + if strings.Contains(upperContent, "CONCURRENTLY") { + return false, errors.New("CONCURRENTLY statements must be placed in *_notx.sql migrations") + } + return false, nil + } + + if strings.Contains(upperContent, "BEGIN") || strings.Contains(upperContent, "COMMIT") || strings.Contains(upperContent, "ROLLBACK") { + return false, errors.New("*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)") + } + + statements := splitSQLStatements(content) + for _, stmt := range statements { + normalizedStmt := strings.ToUpper(stripSQLLineComment(strings.TrimSpace(stmt))) + if normalizedStmt == "" { + continue + } + + if strings.Contains(normalizedStmt, "CONCURRENTLY") { + isCreateIndex := strings.Contains(normalizedStmt, "CREATE") && strings.Contains(normalizedStmt, "INDEX") + isDropIndex := strings.Contains(normalizedStmt, "DROP") && strings.Contains(normalizedStmt, "INDEX") + if !isCreateIndex && !isDropIndex { + return false, errors.New("*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements") + } + if isCreateIndex && !strings.Contains(normalizedStmt, "IF NOT EXISTS") { + return false, errors.New("CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency") + } + if isDropIndex && !strings.Contains(normalizedStmt, "IF EXISTS") { + return false, errors.New("DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency") + } + continue + } + + return false, errors.New("*_notx.sql must not mix non-CONCURRENTLY SQL statements") + } + + return true, nil +} + +func splitSQLStatements(content string) []string { + parts := strings.Split(content, ";") + out := make([]string, 0, len(parts)) + for _, part := range parts { + if strings.TrimSpace(part) == "" { + continue + } + out = append(out, part) + } + return out +} + +func stripSQLLineComment(s string) string { + lines := strings.Split(s, "\n") + for i, line := range lines { + if idx := strings.Index(line, "--"); idx >= 0 { + lines[i] = line[:idx] + } + } + return strings.TrimSpace(strings.Join(lines, "\n")) +} + // pgAdvisoryLock 获取 PostgreSQL Advisory Lock。 // Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。 // 它非常适合用于应用层面的分布式锁场景,如迁移序列化。 diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go new file mode 100644 index 00000000..54f5b0ec --- /dev/null +++ b/backend/internal/repository/migrations_runner_checksum_test.go @@ -0,0 +1,36 @@ +package repository + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsMigrationChecksumCompatible(t *testing.T) { + t.Run("054历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "054_drop_legacy_cache_columns.sql", + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + ) + require.True(t, ok) + }) + + t.Run("054在未知文件checksum下不兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "054_drop_legacy_cache_columns.sql", + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "0000000000000000000000000000000000000000000000000000000000000000", + ) + require.False(t, ok) + }) + + t.Run("非白名单迁移不兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "001_init.sql", + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + ) + require.False(t, ok) + }) +} diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go new file mode 100644 index 00000000..9f8a94c6 --- /dev/null +++ b/backend/internal/repository/migrations_runner_extra_test.go @@ -0,0 +1,368 @@ +package repository + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "io/fs" + "strings" + "testing" + "testing/fstest" + "time" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestApplyMigrations_NilDB(t *testing.T) { + err := ApplyMigrations(context.Background(), nil) + require.Error(t, err) + require.Contains(t, err.Error(), "nil sql db") +} + +func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnError(errors.New("lock failed")) + + err = ApplyMigrations(context.Background(), db) + require.Error(t, err) + require.Contains(t, err.Error(), "acquire migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestLatestMigrationBaseline(t *testing.T) { + t.Run("empty_fs_returns_baseline", func(t *testing.T) { + version, description, hash, err := latestMigrationBaseline(fstest.MapFS{}) + require.NoError(t, err) + require.Equal(t, "baseline", version) + require.Equal(t, "baseline", description) + require.Equal(t, "", hash) + }) + + t.Run("uses_latest_sorted_sql_file", func(t *testing.T) { + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")}, + "010_final.sql": &fstest.MapFile{ + Data: []byte("CREATE TABLE t2(id int);"), + }, + } + version, description, hash, err := latestMigrationBaseline(fsys) + require.NoError(t, err) + require.Equal(t, "010_final", version) + require.Equal(t, "010_final", description) + require.Len(t, hash, 64) + }) + + t.Run("read_file_error", func(t *testing.T) { + fsys := fstest.MapFS{ + "010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir}, + } + _, _, _, err := latestMigrationBaseline(fsys) + require.Error(t, err) + }) +} + +func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) { + require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file")) + + var ( + name string + rule migrationChecksumCompatibilityRule + ) + for n, r := range migrationChecksumCompatibilityRules { + name = n + rule = r + break + } + require.NotEmpty(t, name) + + require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match")) + require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum)) + + var accepted string + for checksum := range rule.acceptedDBChecksum { + accepted = checksum + break + } + require.NotEmpty(t, accepted) + require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum)) +} + +func TestEnsureAtlasBaselineAligned(t *testing.T) { + t.Run("skip_when_no_legacy_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectExec("INSERT INTO atlas_schema_revisions"). + WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")}, + "002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")}, + } + err = ensureAtlasBaselineAligned(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_checking_legacy_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnError(errors.New("exists failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "check schema_migrations") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_counting_atlas_rows", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnError(errors.New("count failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "count atlas_schema_revisions") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_creating_atlas_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions"). + WillReturnError(errors.New("create failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "create atlas_schema_revisions") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_inserting_baseline", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectExec("INSERT INTO atlas_schema_revisions"). + WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()). + WillReturnError(errors.New("insert failed")) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")}, + } + err = ensureAtlasBaselineAligned(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "insert atlas baseline") + require.NoError(t, mock.ExpectationsWereMet()) + }) +} + +func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_init.sql"). + WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum")) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "checksum mismatch") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_err.sql"). + WillReturnError(errors.New("query failed")) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "check migration 001_err.sql") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + + alreadySQL := "CREATE TABLE t(id int);" + checksum := migrationChecksum(alreadySQL) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_already.sql"). + WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "000_empty.sql": &fstest.MapFile{Data: []byte(" \n\t ")}, + "001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "read migration 001_bad.sql") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) { + t.Run("context_cancelled_while_not_locked", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false)) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + err = pgAdvisoryLock(ctx, db) + require.Error(t, err) + require.Contains(t, err.Error(), "acquire migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("unlock_exec_error", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnError(errors.New("unlock failed")) + + err = pgAdvisoryUnlock(context.Background(), db) + require.Error(t, err) + require.Contains(t, err.Error(), "release migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("acquire_lock_after_retry", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false)) + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true)) + + ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3) + defer cancel() + start := time.Now() + err = pgAdvisoryLock(ctx, db) + require.NoError(t, err) + require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval) + require.NoError(t, mock.ExpectationsWereMet()) + }) +} + +func migrationChecksum(content string) string { + sum := sha256.Sum256([]byte(strings.TrimSpace(content))) + return hex.EncodeToString(sum[:]) +} diff --git a/backend/internal/repository/migrations_runner_notx_test.go b/backend/internal/repository/migrations_runner_notx_test.go new file mode 100644 index 00000000..db1183cd --- /dev/null +++ b/backend/internal/repository/migrations_runner_notx_test.go @@ -0,0 +1,164 @@ +package repository + +import ( + "context" + "database/sql" + "testing" + "testing/fstest" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestValidateMigrationExecutionMode(t *testing.T) { + t.Run("事务迁移包含CONCURRENTLY会被拒绝", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移要求CREATE使用IF NOT EXISTS", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移要求DROP使用IF EXISTS", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_drop_idx_notx.sql", "DROP INDEX CONCURRENTLY idx_a;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移禁止事务控制语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移禁止混用非CONCURRENTLY语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移允许幂等并发索引语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", ` +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); +DROP INDEX CONCURRENTLY IF EXISTS idx_b; +`) + require.True(t, nonTx) + require.NoError(t, err) + }) +} + +func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_idx_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_idx_notx.sql": &fstest.MapFile{ + Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_multi_idx_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_multi_idx_notx.sql": &fstest.MapFile{ + Data: []byte(` +-- first +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a); +-- second +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b); +`), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_col.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectBegin() + mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_col.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_col.sql": &fstest.MapFile{ + Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) { + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) +} diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index bc37ee72..72422d18 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -42,12 +42,19 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { // usage_logs: billing_type used by filters/stats requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false) + requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false) + requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false) // settings table should exist var settingsRegclass sql.NullString require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass)) require.True(t, settingsRegclass.Valid, "expected settings table to exist") + // security_secrets table should exist + var securitySecretsRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.security_secrets')").Scan(&securitySecretsRegclass)) + require.True(t, securitySecretsRegclass.Valid, "expected security_secrets table to exist") + // user_allowed_groups table should exist var uagRegclass sql.NullString require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass)) diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 394d3a1a..dca0b612 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/url" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -21,16 +22,23 @@ type openaiOAuthService struct { tokenURL string } -func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) +func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + client, err := createOpenAIReqClient(proxyURL) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) + } if redirectURI == "" { redirectURI = openai.DefaultRedirectURI } + clientID = strings.TrimSpace(clientID) + if clientID == "" { + clientID = openai.ClientID + } formData := url.Values{} formData.Set("grant_type", "authorization_code") - formData.Set("client_id", openai.ClientID) + formData.Set("client_id", clientID) formData.Set("code", code) formData.Set("redirect_uri", redirectURI) formData.Set("code_verifier", codeVerifier) @@ -56,12 +64,28 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie } func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + // 调用方应始终传入正确的 client_id;为兼容旧数据,未指定时默认使用 OpenAI ClientID + clientID = strings.TrimSpace(clientID) + if clientID == "" { + clientID = openai.ClientID + } + return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) +} + +func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { + client, err := createOpenAIReqClient(proxyURL) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) + } formData := url.Values{} formData.Set("grant_type", "refresh_token") formData.Set("refresh_token", refreshToken) - formData.Set("client_id", openai.ClientID) + formData.Set("client_id", clientID) formData.Set("scope", openai.RefreshScopes) var tokenResp openai.TokenResponse @@ -84,7 +108,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro return &tokenResp, nil } -func createOpenAIReqClient(proxyURL string) *req.Client { +func createOpenAIReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 120 * time.Second, diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index f9df08c8..44fa291b 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -81,7 +81,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() { _, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`) })) - resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "") + resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "", "") require.NoError(s.T(), err, "ExchangeCode") select { case msg := <-errCh: @@ -136,13 +136,84 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() { require.Equal(s.T(), "rt2", resp.RefreshToken) } +// TestRefreshToken_DefaultsToOpenAIClientID 验证未指定 client_id 时默认使用 OpenAI ClientID, +// 且只发送一次请求(不再盲猜多个 client_id)。 +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() { + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`) + })) + + resp, err := s.svc.RefreshToken(s.ctx, "rt", "") + require.NoError(s.T(), err, "RefreshToken") + require.Equal(s.T(), "at", resp.AccessToken) + // 只发送了一次请求,使用默认的 OpenAI ClientID + require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs) +} + +// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。 +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() { + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + if clientID == openai.SoraClientID { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + + resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID) + require.NoError(s.T(), err, "RefreshTokenWithClientID") + require.Equal(s.T(), "at-sora", resp.AccessToken) + require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs) +} + +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { + const customClientID = "custom-client-id" + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + if clientID != customClientID { + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`) + })) + + resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID) + require.NoError(s.T(), err, "RefreshTokenWithClientID") + require.Equal(s.T(), "at-custom", resp.AccessToken) + require.Equal(s.T(), "rt-custom", resp.RefreshToken) + require.Equal(s.T(), []string{customClientID}, seenClientIDs) +} + func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) _, _ = io.WriteString(w, "bad") })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.Error(s.T(), err) require.ErrorContains(s.T(), err, "status 400") require.ErrorContains(s.T(), err, "bad") @@ -152,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) s.srv.Close() - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.Error(s.T(), err) require.ErrorContains(s.T(), err, "request failed") } @@ -169,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() { done := make(chan error, 1) go func() { - _, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "", "") done <- err }() @@ -195,7 +266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() { _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`) })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "", "") + require.NoError(s.T(), err, "ExchangeCode") + select { + case msg := <-errCh: + require.Fail(s.T(), msg) + default: + } +} + +func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() { + wantClientID := openai.SoraClientID + errCh := make(chan string, 1) + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + if got := r.PostForm.Get("client_id"); got != wantClientID { + errCh <- "client_id mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`) + })) + + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", wantClientID) require.NoError(s.T(), err, "ExchangeCode") select { case msg := <-errCh: @@ -213,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() { })) s.svc.tokenURL = s.srv.URL + "?x=1" - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.NoError(s.T(), err, "ExchangeCode") select { case <-s.received: @@ -229,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() { _, _ = io.WriteString(w, "not-valid-json") })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.Error(s.T(), err, "expected error for invalid JSON response") } diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go index b04154b7..989573f2 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "database/sql" + "encoding/json" "fmt" "strings" "time" @@ -55,6 +56,10 @@ INSERT INTO ops_error_logs ( upstream_error_message, upstream_error_detail, upstream_errors, + auth_latency_ms, + routing_latency_ms, + upstream_latency_ms, + response_latency_ms, time_to_first_token_ms, request_body, request_body_truncated, @@ -64,7 +69,7 @@ INSERT INTO ops_error_logs ( retry_count, created_at ) VALUES ( - $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34 + $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 ) RETURNING id` var id int64 @@ -97,6 +102,10 @@ INSERT INTO ops_error_logs ( opsNullString(input.UpstreamErrorMessage), opsNullString(input.UpstreamErrorDetail), opsNullString(input.UpstreamErrorsJSON), + opsNullInt64(input.AuthLatencyMs), + opsNullInt64(input.RoutingLatencyMs), + opsNullInt64(input.UpstreamLatencyMs), + opsNullInt64(input.ResponseLatencyMs), opsNullInt64(input.TimeToFirstTokenMs), opsNullString(input.RequestBodyJSON), input.RequestBodyTruncated, @@ -930,6 +939,243 @@ WHERE id = $1` return err } +func (r *opsRepository) BatchInsertSystemLogs(ctx context.Context, inputs []*service.OpsInsertSystemLogInput) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if len(inputs) == 0 { + return 0, nil + } + + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + stmt, err := tx.PrepareContext(ctx, pq.CopyIn( + "ops_system_logs", + "created_at", + "level", + "component", + "message", + "request_id", + "client_request_id", + "user_id", + "account_id", + "platform", + "model", + "extra", + )) + if err != nil { + _ = tx.Rollback() + return 0, err + } + + var inserted int64 + for _, input := range inputs { + if input == nil { + continue + } + createdAt := input.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now().UTC() + } + component := strings.TrimSpace(input.Component) + level := strings.ToLower(strings.TrimSpace(input.Level)) + message := strings.TrimSpace(input.Message) + if level == "" || message == "" { + continue + } + if component == "" { + component = "app" + } + extra := strings.TrimSpace(input.ExtraJSON) + if extra == "" { + extra = "{}" + } + if _, err := stmt.ExecContext( + ctx, + createdAt.UTC(), + level, + component, + message, + opsNullString(input.RequestID), + opsNullString(input.ClientRequestID), + opsNullInt64(input.UserID), + opsNullInt64(input.AccountID), + opsNullString(input.Platform), + opsNullString(input.Model), + extra, + ); err != nil { + _ = stmt.Close() + _ = tx.Rollback() + return inserted, err + } + inserted++ + } + + if _, err := stmt.ExecContext(ctx); err != nil { + _ = stmt.Close() + _ = tx.Rollback() + return inserted, err + } + if err := stmt.Close(); err != nil { + _ = tx.Rollback() + return inserted, err + } + if err := tx.Commit(); err != nil { + return inserted, err + } + return inserted, nil +} + +func (r *opsRepository) ListSystemLogs(ctx context.Context, filter *service.OpsSystemLogFilter) (*service.OpsSystemLogList, error) { + if r == nil || r.db == nil { + return nil, fmt.Errorf("nil ops repository") + } + if filter == nil { + filter = &service.OpsSystemLogFilter{} + } + + page := filter.Page + if page <= 0 { + page = 1 + } + pageSize := filter.PageSize + if pageSize <= 0 { + pageSize = 50 + } + if pageSize > 200 { + pageSize = 200 + } + + where, args, _ := buildOpsSystemLogsWhere(filter) + countSQL := "SELECT COUNT(*) FROM ops_system_logs l " + where + var total int + if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil { + return nil, err + } + + offset := (page - 1) * pageSize + argsWithLimit := append(args, pageSize, offset) + query := ` +SELECT + l.id, + l.created_at, + l.level, + COALESCE(l.component, ''), + COALESCE(l.message, ''), + COALESCE(l.request_id, ''), + COALESCE(l.client_request_id, ''), + l.user_id, + l.account_id, + COALESCE(l.platform, ''), + COALESCE(l.model, ''), + COALESCE(l.extra::text, '{}') +FROM ops_system_logs l +` + where + ` +ORDER BY l.created_at DESC, l.id DESC +LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) + + rows, err := r.db.QueryContext(ctx, query, argsWithLimit...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + logs := make([]*service.OpsSystemLog, 0, pageSize) + for rows.Next() { + item := &service.OpsSystemLog{} + var userID sql.NullInt64 + var accountID sql.NullInt64 + var extraRaw string + if err := rows.Scan( + &item.ID, + &item.CreatedAt, + &item.Level, + &item.Component, + &item.Message, + &item.RequestID, + &item.ClientRequestID, + &userID, + &accountID, + &item.Platform, + &item.Model, + &extraRaw, + ); err != nil { + return nil, err + } + if userID.Valid { + v := userID.Int64 + item.UserID = &v + } + if accountID.Valid { + v := accountID.Int64 + item.AccountID = &v + } + extraRaw = strings.TrimSpace(extraRaw) + if extraRaw != "" && extraRaw != "null" && extraRaw != "{}" { + extra := make(map[string]any) + if err := json.Unmarshal([]byte(extraRaw), &extra); err == nil { + item.Extra = extra + } + } + logs = append(logs, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return &service.OpsSystemLogList{ + Logs: logs, + Total: total, + Page: page, + PageSize: pageSize, + }, nil +} + +func (r *opsRepository) DeleteSystemLogs(ctx context.Context, filter *service.OpsSystemLogCleanupFilter) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if filter == nil { + filter = &service.OpsSystemLogCleanupFilter{} + } + + where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter) + if !hasConstraint { + return 0, fmt.Errorf("cleanup requires at least one filter condition") + } + + query := "DELETE FROM ops_system_logs l " + where + res, err := r.db.ExecContext(ctx, query, args...) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (r *opsRepository) InsertSystemLogCleanupAudit(ctx context.Context, input *service.OpsSystemLogCleanupAudit) error { + if r == nil || r.db == nil { + return fmt.Errorf("nil ops repository") + } + if input == nil { + return fmt.Errorf("nil input") + } + createdAt := input.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now().UTC() + } + _, err := r.db.ExecContext(ctx, ` +INSERT INTO ops_system_log_cleanup_audits ( + created_at, + operator_id, + conditions, + deleted_rows +) VALUES ($1,$2,$3,$4) +`, createdAt.UTC(), input.OperatorID, input.Conditions, input.DeletedRows) + return err +} + func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { clauses := make([]string, 0, 12) args := make([]any, 0, 12) @@ -948,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } // Keep list endpoints scoped to client errors unless explicitly filtering upstream phase. if phaseFilter != "upstream" { - clauses = append(clauses, "COALESCE(status_code, 0) >= 400") + clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400") } if filter.StartTime != nil && !filter.StartTime.IsZero() { @@ -962,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } if p := strings.TrimSpace(filter.Platform); p != "" { args = append(args, p) - clauses = append(clauses, "platform = $"+itoa(len(args))) + clauses = append(clauses, "e.platform = $"+itoa(len(args))) } if filter.GroupID != nil && *filter.GroupID > 0 { args = append(args, *filter.GroupID) - clauses = append(clauses, "group_id = $"+itoa(len(args))) + clauses = append(clauses, "e.group_id = $"+itoa(len(args))) } if filter.AccountID != nil && *filter.AccountID > 0 { args = append(args, *filter.AccountID) - clauses = append(clauses, "account_id = $"+itoa(len(args))) + clauses = append(clauses, "e.account_id = $"+itoa(len(args))) } if phase := phaseFilter; phase != "" { args = append(args, phase) - clauses = append(clauses, "error_phase = $"+itoa(len(args))) + clauses = append(clauses, "e.error_phase = $"+itoa(len(args))) } if filter != nil { if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" { args = append(args, owner) - clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args))) + clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args))) } if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" { args = append(args, source) - clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args))) + clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args))) } } if resolvedFilter != nil { args = append(args, *resolvedFilter) - clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args))) } // View filter: errors vs excluded vs all. @@ -1000,51 +1246,140 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } switch view { case "", "errors": - clauses = append(clauses, "COALESCE(is_business_limited,false) = false") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false") case "excluded": - clauses = append(clauses, "COALESCE(is_business_limited,false) = true") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true") case "all": // no-op default: // treat unknown as default 'errors' - clauses = append(clauses, "COALESCE(is_business_limited,false) = false") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false") } if len(filter.StatusCodes) > 0 { args = append(args, pq.Array(filter.StatusCodes)) - clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")") + clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")") } else if filter.StatusCodesOther { // "Other" means: status codes not in the common list. known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529} args = append(args, pq.Array(known)) - clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))") + clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))") } // Exact correlation keys (preferred for request↔upstream linkage). if rid := strings.TrimSpace(filter.RequestID); rid != "" { args = append(args, rid) - clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args))) } if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" { args = append(args, crid) - clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args))) } if q := strings.TrimSpace(filter.Query); q != "" { like := "%" + q + "%" args = append(args, like) n := itoa(len(args)) - clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")") + clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")") } if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" { like := "%" + userQuery + "%" args = append(args, like) n := itoa(len(args)) - clauses = append(clauses, "u.email ILIKE $"+n) + clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")") } return "WHERE " + strings.Join(clauses, " AND "), args } +func buildOpsSystemLogsWhere(filter *service.OpsSystemLogFilter) (string, []any, bool) { + clauses := make([]string, 0, 10) + args := make([]any, 0, 10) + clauses = append(clauses, "1=1") + hasConstraint := false + + if filter != nil && filter.StartTime != nil && !filter.StartTime.IsZero() { + args = append(args, filter.StartTime.UTC()) + clauses = append(clauses, "l.created_at >= $"+itoa(len(args))) + hasConstraint = true + } + if filter != nil && filter.EndTime != nil && !filter.EndTime.IsZero() { + args = append(args, filter.EndTime.UTC()) + clauses = append(clauses, "l.created_at < $"+itoa(len(args))) + hasConstraint = true + } + if filter != nil { + if v := strings.ToLower(strings.TrimSpace(filter.Level)); v != "" { + args = append(args, v) + clauses = append(clauses, "LOWER(COALESCE(l.level,'')) = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Component); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.component,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.RequestID); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.request_id,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.ClientRequestID); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.client_request_id,'') = $"+itoa(len(args))) + hasConstraint = true + } + if filter.UserID != nil && *filter.UserID > 0 { + args = append(args, *filter.UserID) + clauses = append(clauses, "l.user_id = $"+itoa(len(args))) + hasConstraint = true + } + if filter.AccountID != nil && *filter.AccountID > 0 { + args = append(args, *filter.AccountID) + clauses = append(clauses, "l.account_id = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Platform); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.platform,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Model); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.model,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Query); v != "" { + like := "%" + v + "%" + args = append(args, like) + n := itoa(len(args)) + clauses = append(clauses, "(l.message ILIKE $"+n+" OR COALESCE(l.request_id,'') ILIKE $"+n+" OR COALESCE(l.client_request_id,'') ILIKE $"+n+" OR COALESCE(l.extra::text,'') ILIKE $"+n+")") + hasConstraint = true + } + } + + return "WHERE " + strings.Join(clauses, " AND "), args, hasConstraint +} + +func buildOpsSystemLogsCleanupWhere(filter *service.OpsSystemLogCleanupFilter) (string, []any, bool) { + if filter == nil { + filter = &service.OpsSystemLogCleanupFilter{} + } + listFilter := &service.OpsSystemLogFilter{ + StartTime: filter.StartTime, + EndTime: filter.EndTime, + Level: filter.Level, + Component: filter.Component, + RequestID: filter.RequestID, + ClientRequestID: filter.ClientRequestID, + UserID: filter.UserID, + AccountID: filter.AccountID, + Platform: filter.Platform, + Model: filter.Model, + Query: filter.Query, + } + return buildOpsSystemLogsWhere(listFilter) +} + // Helpers for nullable args func opsNullString(v any) any { switch s := v.(type) { diff --git a/backend/internal/repository/ops_repo_dashboard.go b/backend/internal/repository/ops_repo_dashboard.go index 85791a9a..b43d6706 100644 --- a/backend/internal/repository/ops_repo_dashboard.go +++ b/backend/internal/repository/ops_repo_dashboard.go @@ -12,6 +12,11 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" ) +const ( + opsRawLatencyQueryTimeout = 2 * time.Second + opsRawPeakQueryTimeout = 1500 * time.Millisecond +) + func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) { if r == nil || r.db == nil { return nil, fmt.Errorf("nil ops repository") @@ -45,15 +50,24 @@ func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *servic func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) { start := filter.StartTime.UTC() end := filter.EndTime.UTC() + degraded := false successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end) if err != nil { return nil, err } - duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end) + latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout) + duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end) + cancelLatency() if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + duration = service.OpsPercentiles{} + ttft = service.OpsPercentiles{} + } else { + return nil, err + } } errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end) @@ -75,20 +89,40 @@ func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *ser qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end) if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } - qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end) + peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout) + qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end) + cancelPeak() if err != nil { - return nil, err - } - tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end) - if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds) tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds) + if degraded { + if qpsCurrent <= 0 { + qpsCurrent = qpsAvg + } + if tpsCurrent <= 0 { + tpsCurrent = tpsAvg + } + if qpsPeak <= 0 { + qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg)) + } + if tpsPeak <= 0 { + tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg)) + } + } return &service.OpsDashboardOverview{ StartTime: start, @@ -230,26 +264,45 @@ func (r *opsRepository) getDashboardOverviewPreaggregated(ctx context.Context, f sla := safeDivideFloat64(float64(successCount), float64(requestCountSLA)) errorRate := safeDivideFloat64(float64(errorCountSLA), float64(requestCountSLA)) upstreamErrorRate := safeDivideFloat64(float64(upstreamExcl), float64(requestCountSLA)) + degraded := false // Keep "current" rates as raw, to preserve realtime semantics. qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end) if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } - // NOTE: peak still uses raw logs (minute granularity). This is typically cheaper than percentile_cont - // and keeps semantics consistent across modes. - qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end) + peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout) + qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end) + cancelPeak() if err != nil { - return nil, err - } - tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end) - if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds) tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds) + if degraded { + if qpsCurrent <= 0 { + qpsCurrent = qpsAvg + } + if tpsCurrent <= 0 { + tpsCurrent = tpsAvg + } + if qpsPeak <= 0 { + qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg)) + } + if tpsPeak <= 0 { + tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg)) + } + } return &service.OpsDashboardOverview{ StartTime: start, @@ -577,9 +630,16 @@ func (r *opsRepository) queryRawPartial(ctx context.Context, filter *service.Ops return nil, err } - duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end) + latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout) + duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end) + cancelLatency() if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + duration = service.OpsPercentiles{} + ttft = service.OpsPercentiles{} + } else { + return nil, err + } } errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end) @@ -735,68 +795,56 @@ FROM usage_logs ul } func (r *opsRepository) queryUsageLatency(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (duration service.OpsPercentiles, ttft service.OpsPercentiles, err error) { - { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - q := ` + join, where, args, _ := buildUsageWhere(filter, start, end, 1) + q := ` SELECT - percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) AS p50, - percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) AS p90, - percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95, - percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) AS p99, - AVG(duration_ms) AS avg_ms, - MAX(duration_ms) AS max_ms + percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50, + percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90, + percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95, + percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99, + AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg, + MAX(duration_ms) AS duration_max, + percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50, + percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90, + percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95, + percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99, + AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg, + MAX(first_token_ms) AS ttft_max FROM usage_logs ul ` + join + ` -` + where + ` -AND duration_ms IS NOT NULL` +` + where - var p50, p90, p95, p99 sql.NullFloat64 - var avg sql.NullFloat64 - var max sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil { - return service.OpsPercentiles{}, service.OpsPercentiles{}, err - } - duration.P50 = floatToIntPtr(p50) - duration.P90 = floatToIntPtr(p90) - duration.P95 = floatToIntPtr(p95) - duration.P99 = floatToIntPtr(p99) - duration.Avg = floatToIntPtr(avg) - if max.Valid { - v := int(max.Int64) - duration.Max = &v - } + var dP50, dP90, dP95, dP99 sql.NullFloat64 + var dAvg sql.NullFloat64 + var dMax sql.NullInt64 + var tP50, tP90, tP95, tP99 sql.NullFloat64 + var tAvg sql.NullFloat64 + var tMax sql.NullInt64 + if err := r.db.QueryRowContext(ctx, q, args...).Scan( + &dP50, &dP90, &dP95, &dP99, &dAvg, &dMax, + &tP50, &tP90, &tP95, &tP99, &tAvg, &tMax, + ); err != nil { + return service.OpsPercentiles{}, service.OpsPercentiles{}, err } - { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - q := ` -SELECT - percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) AS p50, - percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) AS p90, - percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) AS p95, - percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) AS p99, - AVG(first_token_ms) AS avg_ms, - MAX(first_token_ms) AS max_ms -FROM usage_logs ul -` + join + ` -` + where + ` -AND first_token_ms IS NOT NULL` + duration.P50 = floatToIntPtr(dP50) + duration.P90 = floatToIntPtr(dP90) + duration.P95 = floatToIntPtr(dP95) + duration.P99 = floatToIntPtr(dP99) + duration.Avg = floatToIntPtr(dAvg) + if dMax.Valid { + v := int(dMax.Int64) + duration.Max = &v + } - var p50, p90, p95, p99 sql.NullFloat64 - var avg sql.NullFloat64 - var max sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil { - return service.OpsPercentiles{}, service.OpsPercentiles{}, err - } - ttft.P50 = floatToIntPtr(p50) - ttft.P90 = floatToIntPtr(p90) - ttft.P95 = floatToIntPtr(p95) - ttft.P99 = floatToIntPtr(p99) - ttft.Avg = floatToIntPtr(avg) - if max.Valid { - v := int(max.Int64) - ttft.Max = &v - } + ttft.P50 = floatToIntPtr(tP50) + ttft.P90 = floatToIntPtr(tP90) + ttft.P95 = floatToIntPtr(tP95) + ttft.P99 = floatToIntPtr(tP99) + ttft.Avg = floatToIntPtr(tAvg) + if tMax.Valid { + v := int(tMax.Int64) + ttft.Max = &v } return duration, ttft, nil @@ -854,20 +902,23 @@ func (r *opsRepository) queryCurrentRates(ctx context.Context, filter *service.O return qpsCurrent, tpsCurrent, nil } -func (r *opsRepository) queryPeakQPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) { +func (r *opsRepository) queryPeakRates(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (qpsPeak float64, tpsPeak float64, err error) { usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1) errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next) q := ` WITH usage_buckets AS ( - SELECT date_trunc('minute', ul.created_at) AS bucket, COUNT(*) AS cnt + SELECT + date_trunc('minute', ul.created_at) AS bucket, + COUNT(*) AS req_cnt, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_cnt FROM usage_logs ul ` + usageJoin + ` ` + usageWhere + ` GROUP BY 1 ), error_buckets AS ( - SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS cnt + SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS err_cnt FROM ops_error_logs ` + errorWhere + ` AND COALESCE(status_code, 0) >= 400 @@ -875,47 +926,33 @@ error_buckets AS ( ), combined AS ( SELECT COALESCE(u.bucket, e.bucket) AS bucket, - COALESCE(u.cnt, 0) + COALESCE(e.cnt, 0) AS total + COALESCE(u.req_cnt, 0) + COALESCE(e.err_cnt, 0) AS total_req, + COALESCE(u.token_cnt, 0) AS total_tokens FROM usage_buckets u FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket ) -SELECT COALESCE(MAX(total), 0) FROM combined` +SELECT + COALESCE(MAX(total_req), 0) AS max_req_per_min, + COALESCE(MAX(total_tokens), 0) AS max_tokens_per_min +FROM combined` args := append(usageArgs, errorArgs...) - var maxPerMinute sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil { - return 0, err + var maxReqPerMinute, maxTokensPerMinute sql.NullInt64 + if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxReqPerMinute, &maxTokensPerMinute); err != nil { + return 0, 0, err } - if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 { - return 0, nil + if maxReqPerMinute.Valid && maxReqPerMinute.Int64 > 0 { + qpsPeak = roundTo1DP(float64(maxReqPerMinute.Int64) / 60.0) } - return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil + if maxTokensPerMinute.Valid && maxTokensPerMinute.Int64 > 0 { + tpsPeak = roundTo1DP(float64(maxTokensPerMinute.Int64) / 60.0) + } + return qpsPeak, tpsPeak, nil } -func (r *opsRepository) queryPeakTPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - - q := ` -SELECT COALESCE(MAX(tokens_per_min), 0) -FROM ( - SELECT - date_trunc('minute', ul.created_at) AS bucket, - COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS tokens_per_min - FROM usage_logs ul - ` + join + ` - ` + where + ` - GROUP BY 1 -) t` - - var maxPerMinute sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil { - return 0, err - } - if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 { - return 0, nil - } - return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil +func isQueryTimeoutErr(err error) bool { + return errors.Is(err, context.DeadlineExceeded) } func buildUsageWhere(filter *service.OpsDashboardFilter, start, end time.Time, startIndex int) (join string, where string, args []any, nextIndex int) { diff --git a/backend/internal/repository/ops_repo_dashboard_timeout_test.go b/backend/internal/repository/ops_repo_dashboard_timeout_test.go new file mode 100644 index 00000000..76332ca0 --- /dev/null +++ b/backend/internal/repository/ops_repo_dashboard_timeout_test.go @@ -0,0 +1,22 @@ +package repository + +import ( + "context" + "fmt" + "testing" +) + +func TestIsQueryTimeoutErr(t *testing.T) { + if !isQueryTimeoutErr(context.DeadlineExceeded) { + t.Fatalf("context.DeadlineExceeded should be treated as query timeout") + } + if !isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.DeadlineExceeded)) { + t.Fatalf("wrapped context.DeadlineExceeded should be treated as query timeout") + } + if isQueryTimeoutErr(context.Canceled) { + t.Fatalf("context.Canceled should not be treated as query timeout") + } + if isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.Canceled)) { + t.Fatalf("wrapped context.Canceled should not be treated as query timeout") + } +} diff --git a/backend/internal/repository/ops_repo_error_where_test.go b/backend/internal/repository/ops_repo_error_where_test.go new file mode 100644 index 00000000..9ab1a89a --- /dev/null +++ b/backend/internal/repository/ops_repo_error_where_test.go @@ -0,0 +1,48 @@ +package repository + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) { + filter := &service.OpsErrorLogFilter{ + Query: "ACCESS_DENIED", + } + + where, args := buildOpsErrorLogsWhere(filter) + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 1 { + t.Fatalf("args len = %d, want 1", len(args)) + } + if !strings.Contains(where, "e.request_id ILIKE $") { + t.Fatalf("where should include qualified request_id condition: %s", where) + } + if !strings.Contains(where, "e.client_request_id ILIKE $") { + t.Fatalf("where should include qualified client_request_id condition: %s", where) + } + if !strings.Contains(where, "e.error_message ILIKE $") { + t.Fatalf("where should include qualified error_message condition: %s", where) + } +} + +func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) { + filter := &service.OpsErrorLogFilter{ + UserQuery: "admin@", + } + + where, args := buildOpsErrorLogsWhere(filter) + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 1 { + t.Fatalf("args len = %d, want 1", len(args)) + } + if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") { + t.Fatalf("where should include EXISTS user email condition: %s", where) + } +} diff --git a/backend/internal/repository/ops_repo_latency_histogram_buckets.go b/backend/internal/repository/ops_repo_latency_histogram_buckets.go index cd5bed37..e56903f1 100644 --- a/backend/internal/repository/ops_repo_latency_histogram_buckets.go +++ b/backend/internal/repository/ops_repo_latency_histogram_buckets.go @@ -35,12 +35,12 @@ func latencyHistogramRangeCaseExpr(column string) string { if b.upperMs <= 0 { continue } - _, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label)) + fmt.Fprintf(&sb, "\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label) } // Default bucket. last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1] - _, _ = sb.WriteString(fmt.Sprintf("\tELSE '%s'\n", last.label)) + fmt.Fprintf(&sb, "\tELSE '%s'\n", last.label) _, _ = sb.WriteString("END") return sb.String() } @@ -54,11 +54,11 @@ func latencyHistogramRangeOrderCaseExpr(column string) string { if b.upperMs <= 0 { continue } - _, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN %d\n", column, b.upperMs, order)) + fmt.Fprintf(&sb, "\tWHEN %s < %d THEN %d\n", column, b.upperMs, order) order++ } - _, _ = sb.WriteString(fmt.Sprintf("\tELSE %d\n", order)) + fmt.Fprintf(&sb, "\tELSE %d\n", order) _, _ = sb.WriteString("END") return sb.String() } diff --git a/backend/internal/repository/ops_repo_openai_token_stats.go b/backend/internal/repository/ops_repo_openai_token_stats.go new file mode 100644 index 00000000..6aea416e --- /dev/null +++ b/backend/internal/repository/ops_repo_openai_token_stats.go @@ -0,0 +1,145 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (r *opsRepository) GetOpenAITokenStats(ctx context.Context, filter *service.OpsOpenAITokenStatsFilter) (*service.OpsOpenAITokenStatsResponse, error) { + if r == nil || r.db == nil { + return nil, fmt.Errorf("nil ops repository") + } + if filter == nil { + return nil, fmt.Errorf("nil filter") + } + if filter.StartTime.IsZero() || filter.EndTime.IsZero() { + return nil, fmt.Errorf("start_time/end_time required") + } + // 允许 start_time == end_time(结果为空),与 service 层校验口径保持一致。 + if filter.StartTime.After(filter.EndTime) { + return nil, fmt.Errorf("start_time must be <= end_time") + } + + dashboardFilter := &service.OpsDashboardFilter{ + StartTime: filter.StartTime.UTC(), + EndTime: filter.EndTime.UTC(), + Platform: strings.TrimSpace(strings.ToLower(filter.Platform)), + GroupID: filter.GroupID, + } + + join, where, baseArgs, next := buildUsageWhere(dashboardFilter, dashboardFilter.StartTime, dashboardFilter.EndTime, 1) + where += " AND ul.model LIKE 'gpt%'" + + baseCTE := ` +WITH stats AS ( + SELECT + ul.model AS model, + COUNT(*)::bigint AS request_count, + ROUND( + AVG( + CASE + WHEN ul.duration_ms > 0 AND ul.output_tokens > 0 + THEN ul.output_tokens * 1000.0 / ul.duration_ms + END + )::numeric, + 2 + )::float8 AS avg_tokens_per_sec, + ROUND(AVG(ul.first_token_ms)::numeric, 2)::float8 AS avg_first_token_ms, + COALESCE(SUM(ul.output_tokens), 0)::bigint AS total_output_tokens, + COALESCE(ROUND(AVG(ul.duration_ms)::numeric, 0), 0)::bigint AS avg_duration_ms, + COUNT(CASE WHEN ul.first_token_ms IS NOT NULL THEN 1 END)::bigint AS requests_with_first_token + FROM usage_logs ul + ` + join + ` + ` + where + ` + GROUP BY ul.model +) +` + + countSQL := baseCTE + `SELECT COUNT(*) FROM stats` + var total int64 + if err := r.db.QueryRowContext(ctx, countSQL, baseArgs...).Scan(&total); err != nil { + return nil, err + } + + querySQL := baseCTE + ` +SELECT + model, + request_count, + avg_tokens_per_sec, + avg_first_token_ms, + total_output_tokens, + avg_duration_ms, + requests_with_first_token +FROM stats +ORDER BY request_count DESC, model ASC` + + args := make([]any, 0, len(baseArgs)+2) + args = append(args, baseArgs...) + + if filter.IsTopNMode() { + querySQL += fmt.Sprintf("\nLIMIT $%d", next) + args = append(args, filter.TopN) + } else { + offset := (filter.Page - 1) * filter.PageSize + querySQL += fmt.Sprintf("\nLIMIT $%d OFFSET $%d", next, next+1) + args = append(args, filter.PageSize, offset) + } + + rows, err := r.db.QueryContext(ctx, querySQL, args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + items := make([]*service.OpsOpenAITokenStatsItem, 0, 32) + for rows.Next() { + item := &service.OpsOpenAITokenStatsItem{} + var avgTPS sql.NullFloat64 + var avgFirstToken sql.NullFloat64 + if err := rows.Scan( + &item.Model, + &item.RequestCount, + &avgTPS, + &avgFirstToken, + &item.TotalOutputTokens, + &item.AvgDurationMs, + &item.RequestsWithFirstToken, + ); err != nil { + return nil, err + } + if avgTPS.Valid { + v := avgTPS.Float64 + item.AvgTokensPerSec = &v + } + if avgFirstToken.Valid { + v := avgFirstToken.Float64 + item.AvgFirstTokenMs = &v + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + + resp := &service.OpsOpenAITokenStatsResponse{ + TimeRange: strings.TrimSpace(filter.TimeRange), + StartTime: dashboardFilter.StartTime, + EndTime: dashboardFilter.EndTime, + Platform: dashboardFilter.Platform, + GroupID: dashboardFilter.GroupID, + Items: items, + Total: total, + } + if filter.IsTopNMode() { + topN := filter.TopN + resp.TopN = &topN + } else { + resp.Page = filter.Page + resp.PageSize = filter.PageSize + } + return resp, nil +} diff --git a/backend/internal/repository/ops_repo_openai_token_stats_test.go b/backend/internal/repository/ops_repo_openai_token_stats_test.go new file mode 100644 index 00000000..bb01d820 --- /dev/null +++ b/backend/internal/repository/ops_repo_openai_token_stats_test.go @@ -0,0 +1,156 @@ +package repository + +import ( + "context" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestOpsRepositoryGetOpenAITokenStats_PaginationMode(t *testing.T) { + db, mock := newSQLMock(t) + repo := &opsRepository{db: db} + + start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + groupID := int64(9) + + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: "1d", + StartTime: start, + EndTime: end, + Platform: " OpenAI ", + GroupID: &groupID, + Page: 2, + PageSize: 10, + } + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`). + WithArgs(start, end, groupID, "openai"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(3))) + + rows := sqlmock.NewRows([]string{ + "model", + "request_count", + "avg_tokens_per_sec", + "avg_first_token_ms", + "total_output_tokens", + "avg_duration_ms", + "requests_with_first_token", + }). + AddRow("gpt-4o-mini", int64(20), 21.56, 120.34, int64(3000), int64(850), int64(18)). + AddRow("gpt-4.1", int64(20), 10.2, 240.0, int64(2500), int64(900), int64(20)) + + mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$5 OFFSET \$6`). + WithArgs(start, end, groupID, "openai", 10, 10). + WillReturnRows(rows) + + resp, err := repo.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, int64(3), resp.Total) + require.Equal(t, 2, resp.Page) + require.Equal(t, 10, resp.PageSize) + require.Nil(t, resp.TopN) + require.Equal(t, "openai", resp.Platform) + require.NotNil(t, resp.GroupID) + require.Equal(t, groupID, *resp.GroupID) + require.Len(t, resp.Items, 2) + require.Equal(t, "gpt-4o-mini", resp.Items[0].Model) + require.NotNil(t, resp.Items[0].AvgTokensPerSec) + require.InDelta(t, 21.56, *resp.Items[0].AvgTokensPerSec, 0.0001) + require.NotNil(t, resp.Items[0].AvgFirstTokenMs) + require.InDelta(t, 120.34, *resp.Items[0].AvgFirstTokenMs, 0.0001) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestOpsRepositoryGetOpenAITokenStats_TopNMode(t *testing.T) { + db, mock := newSQLMock(t) + repo := &opsRepository{db: db} + + start := time.Date(2026, 1, 1, 10, 0, 0, 0, time.UTC) + end := start.Add(time.Hour) + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: "1h", + StartTime: start, + EndTime: end, + TopN: 5, + } + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`). + WithArgs(start, end). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1))) + + rows := sqlmock.NewRows([]string{ + "model", + "request_count", + "avg_tokens_per_sec", + "avg_first_token_ms", + "total_output_tokens", + "avg_duration_ms", + "requests_with_first_token", + }). + AddRow("gpt-4o", int64(5), nil, nil, int64(0), int64(0), int64(0)) + + mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3`). + WithArgs(start, end, 5). + WillReturnRows(rows) + + resp, err := repo.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.TopN) + require.Equal(t, 5, *resp.TopN) + require.Equal(t, 0, resp.Page) + require.Equal(t, 0, resp.PageSize) + require.Len(t, resp.Items, 1) + require.Nil(t, resp.Items[0].AvgTokensPerSec) + require.Nil(t, resp.Items[0].AvgFirstTokenMs) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestOpsRepositoryGetOpenAITokenStats_EmptyResult(t *testing.T) { + db, mock := newSQLMock(t) + repo := &opsRepository{db: db} + + start := time.Date(2026, 1, 2, 0, 0, 0, 0, time.UTC) + end := start.Add(30 * time.Minute) + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: "30m", + StartTime: start, + EndTime: end, + Page: 1, + PageSize: 20, + } + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`). + WithArgs(start, end). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0))) + + mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3 OFFSET \$4`). + WithArgs(start, end, 20, 0). + WillReturnRows(sqlmock.NewRows([]string{ + "model", + "request_count", + "avg_tokens_per_sec", + "avg_first_token_ms", + "total_output_tokens", + "avg_duration_ms", + "requests_with_first_token", + })) + + resp, err := repo.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, int64(0), resp.Total) + require.Len(t, resp.Items, 0) + require.Equal(t, 1, resp.Page) + require.Equal(t, 20, resp.PageSize) + + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/backend/internal/repository/ops_repo_system_logs_test.go b/backend/internal/repository/ops_repo_system_logs_test.go new file mode 100644 index 00000000..c3524fe4 --- /dev/null +++ b/backend/internal/repository/ops_repo_system_logs_test.go @@ -0,0 +1,86 @@ +package repository + +import ( + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestBuildOpsSystemLogsWhere_WithClientRequestIDAndUserID(t *testing.T) { + start := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC) + end := time.Date(2026, 2, 2, 0, 0, 0, 0, time.UTC) + userID := int64(12) + accountID := int64(34) + + filter := &service.OpsSystemLogFilter{ + StartTime: &start, + EndTime: &end, + Level: "warn", + Component: "http.access", + RequestID: "req-1", + ClientRequestID: "creq-1", + UserID: &userID, + AccountID: &accountID, + Platform: "openai", + Model: "gpt-5", + Query: "timeout", + } + + where, args, hasConstraint := buildOpsSystemLogsWhere(filter) + if !hasConstraint { + t.Fatalf("expected hasConstraint=true") + } + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 11 { + t.Fatalf("args len = %d, want 11", len(args)) + } + if !contains(where, "COALESCE(l.client_request_id,'') = $") { + t.Fatalf("where should include client_request_id condition: %s", where) + } + if !contains(where, "l.user_id = $") { + t.Fatalf("where should include user_id condition: %s", where) + } +} + +func TestBuildOpsSystemLogsCleanupWhere_RequireConstraint(t *testing.T) { + where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(&service.OpsSystemLogCleanupFilter{}) + if hasConstraint { + t.Fatalf("expected hasConstraint=false") + } + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 0 { + t.Fatalf("args len = %d, want 0", len(args)) + } +} + +func TestBuildOpsSystemLogsCleanupWhere_WithClientRequestIDAndUserID(t *testing.T) { + userID := int64(9) + filter := &service.OpsSystemLogCleanupFilter{ + ClientRequestID: "creq-9", + UserID: &userID, + } + + where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter) + if !hasConstraint { + t.Fatalf("expected hasConstraint=true") + } + if len(args) != 2 { + t.Fatalf("args len = %d, want 2", len(args)) + } + if !contains(where, "COALESCE(l.client_request_id,'') = $") { + t.Fatalf("where should include client_request_id condition: %s", where) + } + if !contains(where, "l.user_id = $") { + t.Fatalf("where should include user_id condition: %s", where) + } +} + +func contains(s string, sub string) bool { + return strings.Contains(s, sub) +} diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go index 07d796b8..ee8e1749 100644 --- a/backend/internal/repository/pricing_service.go +++ b/backend/internal/repository/pricing_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "log/slog" "net/http" "strings" "time" @@ -16,14 +17,37 @@ type pricingRemoteClient struct { httpClient *http.Client } +// pricingRemoteClientError 代理初始化失败时的错误占位客户端 +// 所有请求直接返回初始化错误,禁止回退到直连 +type pricingRemoteClientError struct { + err error +} + +func (c *pricingRemoteClientError) FetchPricingJSON(_ context.Context, _ string) ([]byte, error) { + return nil, c.err +} + +func (c *pricingRemoteClientError) FetchHashText(_ context.Context, _ string) (string, error) { + return "", c.err +} + // NewPricingRemoteClient 创建定价数据远程客户端 // proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议 -func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient { +// 代理配置失败时行为由 allowDirectOnProxyError 控制: +// - false(默认):返回错误占位客户端,禁止回退到直连 +// - true:回退到直连(仅限管理员显式开启) +func NewPricingRemoteClient(proxyURL string, allowDirectOnProxyError bool) service.PricingRemoteClient { + // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据, + // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy client init failed, all requests will fail", "service", "pricing", "error", err) + return &pricingRemoteClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} + } sharedClient = &http.Client{Timeout: 30 * time.Second} } return &pricingRemoteClient{ diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go index 6ea11211..ef2f214b 100644 --- a/backend/internal/repository/pricing_service_test.go +++ b/backend/internal/repository/pricing_service_test.go @@ -19,7 +19,7 @@ type PricingServiceSuite struct { func (s *PricingServiceSuite) SetupTest() { s.ctx = context.Background() - client, ok := NewPricingRemoteClient("").(*pricingRemoteClient) + client, ok := NewPricingRemoteClient("", false).(*pricingRemoteClient) require.True(s.T(), ok, "type assertion failed") s.client = client } @@ -140,6 +140,22 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() { require.Error(s.T(), err) } +func TestNewPricingRemoteClient_InvalidProxy_NoFallback(t *testing.T) { + client := NewPricingRemoteClient("://bad", false) + _, ok := client.(*pricingRemoteClientError) + require.True(t, ok, "should return error client when proxy is invalid and fallback disabled") + + _, err := client.FetchPricingJSON(context.Background(), "http://example.com") + require.Error(t, err) + require.Contains(t, err.Error(), "proxy client init failed") +} + +func TestNewPricingRemoteClient_InvalidProxy_WithFallback(t *testing.T) { + client := NewPricingRemoteClient("://bad", true) + _, ok := client.(*pricingRemoteClient) + require.True(t, ok, "should fallback to direct client when allowed") +} + func TestPricingServiceSuite(t *testing.T) { suite.Run(t, new(PricingServiceSuite)) } diff --git a/backend/internal/repository/promo_code_repo.go b/backend/internal/repository/promo_code_repo.go index 98b422e0..95ce687a 100644 --- a/backend/internal/repository/promo_code_repo.go +++ b/backend/internal/repository/promo_code_repo.go @@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina q = q.Where(promocode.CodeContainsFold(search)) } - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } @@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo q := r.client.PromoCodeUsage.Query(). Where(promocodeusage.PromoCodeIDEQ(promoCodeID)) - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 513e929c..b4aeab71 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { insecure := false allowPrivate := false validateResolvedIP := true + maxResponseBytes := defaultProxyProbeResponseMaxBytes if cfg != nil { insecure = cfg.Security.ProxyProbe.InsecureSkipVerify allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts validateResolvedIP = cfg.Security.URLAllowlist.Enabled + if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 { + maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes + } } if insecure { log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.") @@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { insecureSkipVerify: insecure, allowPrivateHosts: allowPrivate, validateResolvedIP: validateResolvedIP, + maxResponseBytes: maxResponseBytes, } } const ( - defaultProxyProbeTimeout = 30 * time.Second + defaultProxyProbeTimeout = 30 * time.Second + defaultProxyProbeResponseMaxBytes = int64(1024 * 1024) ) // probeURLs 按优先级排列的探测 URL 列表 @@ -52,6 +58,7 @@ type proxyProbeService struct { insecureSkipVerify bool allowPrivateHosts bool validateResolvedIP bool + maxResponseBytes int64 } func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { @@ -59,7 +66,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ProxyURL: proxyURL, Timeout: defaultProxyProbeTimeout, InsecureSkipVerify: s.insecureSkipVerify, - ProxyStrict: true, ValidateResolvedIP: s.validateResolvedIP, AllowPrivateHosts: s.allowPrivateHosts, }) @@ -98,10 +104,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) } - body, err := io.ReadAll(resp.Body) + maxResponseBytes := s.maxResponseBytes + if maxResponseBytes <= 0 { + maxResponseBytes = defaultProxyProbeResponseMaxBytes + } + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1)) if err != nil { return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) } + if int64(len(body)) > maxResponseBytes { + return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes) + } switch parser { case "ip-api": diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index a3a048c3..934a3095 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -6,6 +6,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -106,7 +107,12 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin q = q.Where(redeemcode.StatusEQ(status)) } if search != "" { - q = q.Where(redeemcode.CodeContainsFold(search)) + q = q.Where( + redeemcode.Or( + redeemcode.CodeContainsFold(search), + redeemcode.HasUserWith(user.EmailContainsFold(search)), + ), + ) } total, err := q.Count(ctx) diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go index af71a7ee..79b24396 100644 --- a/backend/internal/repository/req_client_pool.go +++ b/backend/internal/repository/req_client_pool.go @@ -6,6 +6,8 @@ import ( "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/imroc/req/v3" ) @@ -33,11 +35,11 @@ var sharedReqClients sync.Map // getSharedReqClient 获取共享的 req 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 -func getSharedReqClient(opts reqClientOptions) *req.Client { +func getSharedReqClient(opts reqClientOptions) (*req.Client, error) { key := buildReqClientKey(opts) if cached, ok := sharedReqClients.Load(key); ok { if c, ok := cached.(*req.Client); ok { - return c + return c, nil } } @@ -48,15 +50,19 @@ func getSharedReqClient(opts reqClientOptions) *req.Client { if opts.Impersonate { client = client.ImpersonateChrome() } - if strings.TrimSpace(opts.ProxyURL) != "" { - client.SetProxyURL(strings.TrimSpace(opts.ProxyURL)) + trimmed, _, err := proxyurl.Parse(opts.ProxyURL) + if err != nil { + return nil, err + } + if trimmed != "" { + client.SetProxyURL(trimmed) } actual, _ := sharedReqClients.LoadOrStore(key, client) if c, ok := actual.(*req.Client); ok { - return c + return c, nil } - return client + return client, nil } func buildReqClientKey(opts reqClientOptions) string { diff --git a/backend/internal/repository/req_client_pool_test.go b/backend/internal/repository/req_client_pool_test.go index 904ed4d6..9067d012 100644 --- a/backend/internal/repository/req_client_pool_test.go +++ b/backend/internal/repository/req_client_pool_test.go @@ -26,11 +26,13 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: time.Second, } - clientDefault := getSharedReqClient(base) + clientDefault, err := getSharedReqClient(base) + require.NoError(t, err) force := base force.ForceHTTP2 = true - clientForce := getSharedReqClient(force) + clientForce, err := getSharedReqClient(force) + require.NoError(t, err) require.NotSame(t, clientDefault, clientForce) require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force)) @@ -42,8 +44,10 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: 2 * time.Second, } - first := getSharedReqClient(opts) - second := getSharedReqClient(opts) + first, err := getSharedReqClient(opts) + require.NoError(t, err) + second, err := getSharedReqClient(opts) + require.NoError(t, err) require.Same(t, first, second) } @@ -56,7 +60,8 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) { key := buildReqClientKey(opts) sharedReqClients.Store(key, "invalid") - client := getSharedReqClient(opts) + client, err := getSharedReqClient(opts) + require.NoError(t, err) require.NotNil(t, client) loaded, ok := sharedReqClients.Load(key) @@ -71,20 +76,45 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) { Timeout: 4 * time.Second, Impersonate: true, } - client := getSharedReqClient(opts) + client, err := getSharedReqClient(opts) + require.NoError(t, err) require.NotNil(t, client) require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts)) } +func TestGetSharedReqClient_InvalidProxyURL(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: "://missing-scheme", + Timeout: time.Second, + } + _, err := getSharedReqClient(opts) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid proxy URL") +} + +func TestGetSharedReqClient_ProxyURLMissingHost(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: "http://", + Timeout: time.Second, + } + _, err := getSharedReqClient(opts) + require.Error(t, err) + require.Contains(t, err.Error(), "proxy URL missing host") +} + func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) { sharedReqClients = sync.Map{} - client := createOpenAIReqClient("http://proxy.local:8080") + client, err := createOpenAIReqClient("http://proxy.local:8080") + require.NoError(t, err) require.Equal(t, 120*time.Second, client.GetClient().Timeout) } func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) { sharedReqClients = sync.Map{} - client := createGeminiReqClient("http://proxy.local:8080") + client, err := createGeminiReqClient("http://proxy.local:8080") + require.NoError(t, err) require.Equal(t, "", forceHTTPVersion(t, client)) } diff --git a/backend/internal/repository/rpm_cache.go b/backend/internal/repository/rpm_cache.go new file mode 100644 index 00000000..4d73ec4b --- /dev/null +++ b/backend/internal/repository/rpm_cache.go @@ -0,0 +1,141 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "strconv" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +// RPM 计数器缓存常量定义 +// +// 设计说明: +// 使用 Redis 简单计数器跟踪每个账号每分钟的请求数: +// - Key: rpm:{accountID}:{minuteTimestamp} +// - Value: 当前分钟内的请求计数 +// - TTL: 120 秒(覆盖当前分钟 + 一定冗余) +// +// 使用 TxPipeline(MULTI/EXEC)执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster。 +// 通过 rdb.Time() 获取服务端时间,避免多实例时钟不同步。 +// +// 设计决策: +// - TxPipeline vs Pipeline:Pipeline 仅合并发送但不保证原子,TxPipeline 使用 MULTI/EXEC 事务保证原子执行。 +// - rdb.Time() 单独调用:Pipeline/TxPipeline 中无法引用前一命令的结果,因此 TIME 必须单独调用(2 RTT)。 +// Lua 脚本可以做到 1 RTT,但在 Redis Cluster 中动态拼接 key 存在 CROSSSLOT 风险,选择安全性优先。 +const ( + // RPM 计数器键前缀 + // 格式: rpm:{accountID}:{minuteTimestamp} + rpmKeyPrefix = "rpm:" + + // RPM 计数器 TTL(120 秒,覆盖当前分钟窗口 + 冗余) + rpmKeyTTL = 120 * time.Second +) + +// RPMCacheImpl RPM 计数器缓存 Redis 实现 +type RPMCacheImpl struct { + rdb *redis.Client +} + +// NewRPMCache 创建 RPM 计数器缓存 +func NewRPMCache(rdb *redis.Client) service.RPMCache { + return &RPMCacheImpl{rdb: rdb} +} + +// currentMinuteKey 获取当前分钟的完整 Redis key +// 使用 rdb.Time() 获取 Redis 服务端时间,避免多实例时钟偏差 +func (c *RPMCacheImpl) currentMinuteKey(ctx context.Context, accountID int64) (string, error) { + serverTime, err := c.rdb.Time(ctx).Result() + if err != nil { + return "", fmt.Errorf("redis TIME: %w", err) + } + minuteTS := serverTime.Unix() / 60 + return fmt.Sprintf("%s%d:%d", rpmKeyPrefix, accountID, minuteTS), nil +} + +// currentMinuteSuffix 获取当前分钟时间戳后缀(供批量操作使用) +// 使用 rdb.Time() 获取 Redis 服务端时间 +func (c *RPMCacheImpl) currentMinuteSuffix(ctx context.Context) (string, error) { + serverTime, err := c.rdb.Time(ctx).Result() + if err != nil { + return "", fmt.Errorf("redis TIME: %w", err) + } + minuteTS := serverTime.Unix() / 60 + return strconv.FormatInt(minuteTS, 10), nil +} + +// IncrementRPM 原子递增并返回当前分钟的计数 +// 使用 TxPipeline (MULTI/EXEC) 执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster +func (c *RPMCacheImpl) IncrementRPM(ctx context.Context, accountID int64) (int, error) { + key, err := c.currentMinuteKey(ctx, accountID) + if err != nil { + return 0, fmt.Errorf("rpm increment: %w", err) + } + + // 使用 TxPipeline (MULTI/EXEC) 保证 INCR + EXPIRE 原子执行 + // EXPIRE 幂等,每次都设置不影响正确性 + pipe := c.rdb.TxPipeline() + incrCmd := pipe.Incr(ctx, key) + pipe.Expire(ctx, key, rpmKeyTTL) + + if _, err := pipe.Exec(ctx); err != nil { + return 0, fmt.Errorf("rpm increment: %w", err) + } + + return int(incrCmd.Val()), nil +} + +// GetRPM 获取当前分钟的 RPM 计数 +func (c *RPMCacheImpl) GetRPM(ctx context.Context, accountID int64) (int, error) { + key, err := c.currentMinuteKey(ctx, accountID) + if err != nil { + return 0, fmt.Errorf("rpm get: %w", err) + } + + val, err := c.rdb.Get(ctx, key).Int() + if errors.Is(err, redis.Nil) { + return 0, nil // 当前分钟无记录 + } + if err != nil { + return 0, fmt.Errorf("rpm get: %w", err) + } + return val, nil +} + +// GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline) +func (c *RPMCacheImpl) GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + if len(accountIDs) == 0 { + return map[int64]int{}, nil + } + + // 获取当前分钟后缀 + minuteSuffix, err := c.currentMinuteSuffix(ctx) + if err != nil { + return nil, fmt.Errorf("rpm batch get: %w", err) + } + + // 使用 Pipeline 批量 GET + pipe := c.rdb.Pipeline() + cmds := make(map[int64]*redis.StringCmd, len(accountIDs)) + for _, id := range accountIDs { + key := fmt.Sprintf("%s%d:%s", rpmKeyPrefix, id, minuteSuffix) + cmds[id] = pipe.Get(ctx, key) + } + + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("rpm batch get: %w", err) + } + + result := make(map[int64]int, len(accountIDs)) + for id, cmd := range cmds { + if val, err := cmd.Int(); err == nil { + result[id] = val + } else { + result[id] = 0 + } + } + return result, nil +} diff --git a/backend/internal/repository/security_secret_bootstrap.go b/backend/internal/repository/security_secret_bootstrap.go new file mode 100644 index 00000000..e773c238 --- /dev/null +++ b/backend/internal/repository/security_secret_bootstrap.go @@ -0,0 +1,177 @@ +package repository + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "log" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" + "github.com/Wei-Shaw/sub2api/internal/config" +) + +const ( + securitySecretKeyJWT = "jwt_secret" + securitySecretReadRetryMax = 5 + securitySecretReadRetryWait = 10 * time.Millisecond +) + +var readRandomBytes = rand.Read + +func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config.Config) error { + if client == nil { + return fmt.Errorf("nil ent client") + } + if cfg == nil { + return fmt.Errorf("nil config") + } + + cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) + if cfg.JWT.Secret != "" { + storedSecret, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret) + if err != nil { + return fmt.Errorf("persist jwt secret: %w", err) + } + if storedSecret != cfg.JWT.Secret { + log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.") + } + cfg.JWT.Secret = storedSecret + return nil + } + + secret, created, err := getOrCreateGeneratedSecuritySecret(ctx, client, securitySecretKeyJWT, 32) + if err != nil { + return fmt.Errorf("ensure jwt secret: %w", err) + } + cfg.JWT.Secret = secret + + if created { + log.Println("Warning: JWT secret auto-generated and persisted to database. Consider rotating to a managed secret for production.") + } + return nil +} + +func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, key string, byteLength int) (string, bool, error) { + existing, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx) + if err == nil { + value := strings.TrimSpace(existing.Value) + if len([]byte(value)) < 32 { + return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key) + } + return value, false, nil + } + if !ent.IsNotFound(err) { + return "", false, err + } + + generated, err := generateHexSecret(byteLength) + if err != nil { + return "", false, err + } + + if err := client.SecuritySecret.Create(). + SetKey(key). + SetValue(generated). + OnConflictColumns(securitysecret.FieldKey). + DoNothing(). + Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return "", false, err + } + } + + stored, err := querySecuritySecretWithRetry(ctx, client, key) + if err != nil { + return "", false, err + } + value := strings.TrimSpace(stored.Value) + if len([]byte(value)) < 32 { + return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key) + } + return value, value == generated, nil +} + +func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) { + value = strings.TrimSpace(value) + if len([]byte(value)) < 32 { + return "", fmt.Errorf("secret %q must be at least 32 bytes", key) + } + + if err := client.SecuritySecret.Create(). + SetKey(key). + SetValue(value). + OnConflictColumns(securitysecret.FieldKey). + DoNothing(). + Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return "", err + } + } + + stored, err := querySecuritySecretWithRetry(ctx, client, key) + if err != nil { + return "", err + } + storedValue := strings.TrimSpace(stored.Value) + if len([]byte(storedValue)) < 32 { + return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key) + } + return storedValue, nil +} + +func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) { + var lastErr error + for attempt := 0; attempt <= securitySecretReadRetryMax; attempt++ { + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx) + if err == nil { + return stored, nil + } + if !isSecretNotFoundError(err) { + return nil, err + } + lastErr = err + if attempt == securitySecretReadRetryMax { + break + } + + timer := time.NewTimer(securitySecretReadRetryWait) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-timer.C: + } + } + return nil, lastErr +} + +func isSecretNotFoundError(err error) bool { + if err == nil { + return false + } + return ent.IsNotFound(err) || isSQLNoRowsError(err) +} + +func isSQLNoRowsError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set") +} + +func generateHexSecret(byteLength int) (string, error) { + if byteLength <= 0 { + byteLength = 32 + } + buf := make([]byte, byteLength) + if _, err := readRandomBytes(buf); err != nil { + return "", fmt.Errorf("generate random secret: %w", err) + } + return hex.EncodeToString(buf), nil +} diff --git a/backend/internal/repository/security_secret_bootstrap_test.go b/backend/internal/repository/security_secret_bootstrap_test.go new file mode 100644 index 00000000..288edf33 --- /dev/null +++ b/backend/internal/repository/security_secret_bootstrap_test.go @@ -0,0 +1,337 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "strings" + "sync" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newSecuritySecretTestClient(t *testing.T) *dbent.Client { + t.Helper() + name := strings.ReplaceAll(t.Name(), "/", "_") + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", name) + + db, err := sql.Open("sqlite", dsn) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +func TestEnsureBootstrapSecretsNilInputs(t *testing.T) { + err := ensureBootstrapSecrets(context.Background(), nil, &config.Config{}) + require.Error(t, err) + require.Contains(t, err.Error(), "nil ent client") + + client := newSecuritySecretTestClient(t) + err = ensureBootstrapSecrets(context.Background(), client, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "nil config") +} + +func TestEnsureBootstrapSecretsGenerateAndPersistJWTSecret(t *testing.T) { + client := newSecuritySecretTestClient(t) + cfg := &config.Config{} + + err := ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + require.NotEmpty(t, cfg.JWT.Secret) + require.GreaterOrEqual(t, len([]byte(cfg.JWT.Secret)), 32) + + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background()) + require.NoError(t, err) + require.Equal(t, cfg.JWT.Secret, stored.Value) +} + +func TestEnsureBootstrapSecretsLoadExistingJWTSecret(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("existing-jwt-secret-32bytes-long!!!!").Save(context.Background()) + require.NoError(t, err) + + cfg := &config.Config{} + err = ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret) +} + +func TestEnsureBootstrapSecretsRejectInvalidStoredSecret(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("too-short").Save(context.Background()) + require.NoError(t, err) + + cfg := &config.Config{} + err = ensureBootstrapSecrets(context.Background(), client, cfg) + require.Error(t, err) + require.Contains(t, err.Error(), "at least 32 bytes") +} + +func TestEnsureBootstrapSecretsPersistConfiguredJWTSecret(t *testing.T) { + client := newSecuritySecretTestClient(t) + cfg := &config.Config{ + JWT: config.JWTConfig{Secret: "configured-jwt-secret-32bytes-long!!"}, + } + + err := ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background()) + require.NoError(t, err) + require.Equal(t, "configured-jwt-secret-32bytes-long!!", stored.Value) +} + +func TestEnsureBootstrapSecretsConfiguredSecretTooShort(t *testing.T) { + client := newSecuritySecretTestClient(t) + cfg := &config.Config{JWT: config.JWTConfig{Secret: "short"}} + + err := ensureBootstrapSecrets(context.Background(), client, cfg) + require.Error(t, err) + require.Contains(t, err.Error(), "at least 32 bytes") +} + +func TestEnsureBootstrapSecretsConfiguredSecretDuplicateIgnored(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create(). + SetKey(securitySecretKeyJWT). + SetValue("existing-jwt-secret-32bytes-long!!!!"). + Save(context.Background()) + require.NoError(t, err) + + cfg := &config.Config{JWT: config.JWTConfig{Secret: "another-configured-jwt-secret-32!!!!"}} + err = ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background()) + require.NoError(t, err) + require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value) + require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret) +} + +func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create(). + SetKey("trimmed_key"). + SetValue(" existing-trimmed-secret-32bytes-long!! "). + Save(context.Background()) + require.NoError(t, err) + + value, created, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "trimmed_key", 32) + require.NoError(t, err) + require.False(t, created) + require.Equal(t, "existing-trimmed-secret-32bytes-long!!", value) +} + +func TestGetOrCreateGeneratedSecuritySecretQueryError(t *testing.T) { + client := newSecuritySecretTestClient(t) + require.NoError(t, client.Close()) + + _, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "closed_client_key", 32) + require.Error(t, err) +} + +func TestGetOrCreateGeneratedSecuritySecretCreateValidationError(t *testing.T) { + client := newSecuritySecretTestClient(t) + tooLongKey := strings.Repeat("k", 101) + + _, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, tooLongKey, 32) + require.Error(t, err) +} + +func TestGetOrCreateGeneratedSecuritySecretConcurrentCreation(t *testing.T) { + client := newSecuritySecretTestClient(t) + const goroutines = 8 + key := "concurrent_bootstrap_key" + + values := make([]string, goroutines) + createdFlags := make([]bool, goroutines) + errs := make([]error, goroutines) + + var wg sync.WaitGroup + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + values[idx], createdFlags[idx], errs[idx] = getOrCreateGeneratedSecuritySecret(context.Background(), client, key, 32) + }(i) + } + wg.Wait() + + for i := range errs { + require.NoError(t, errs[i]) + require.NotEmpty(t, values[i]) + } + for i := 1; i < len(values); i++ { + require.Equal(t, values[0], values[i]) + } + + createdCount := 0 + for _, created := range createdFlags { + if created { + createdCount++ + } + } + require.GreaterOrEqual(t, createdCount, 1) + require.LessOrEqual(t, createdCount, 1) + + count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Count(context.Background()) + require.NoError(t, err) + require.Equal(t, 1, count) +} + +func TestGetOrCreateGeneratedSecuritySecretGenerateError(t *testing.T) { + client := newSecuritySecretTestClient(t) + originalRead := readRandomBytes + readRandomBytes = func([]byte) (int, error) { + return 0, errors.New("boom") + } + t.Cleanup(func() { + readRandomBytes = originalRead + }) + + _, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "gen_error_key", 32) + require.Error(t, err) + require.Contains(t, err.Error(), "boom") +} + +func TestCreateSecuritySecretIfAbsent(t *testing.T) { + client := newSecuritySecretTestClient(t) + + _, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short") + require.Error(t, err) + require.Contains(t, err.Error(), "at least 32 bytes") + + stored, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long") + require.NoError(t, err) + require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored) + + stored, err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes") + require.NoError(t, err) + require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored) + + count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background()) + require.NoError(t, err) + require.Equal(t, 1, count) +} + +func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := createSecuritySecretIfAbsent( + context.Background(), + client, + strings.Repeat("k", 101), + "valid-jwt-secret-value-32bytes-long", + ) + require.Error(t, err) +} + +func TestCreateSecuritySecretIfAbsentExecError(t *testing.T) { + client := newSecuritySecretTestClient(t) + require.NoError(t, client.Close()) + + _, err := createSecuritySecretIfAbsent(context.Background(), client, "closed-client-key", "valid-jwt-secret-value-32bytes-long") + require.Error(t, err) +} + +func TestQuerySecuritySecretWithRetrySuccess(t *testing.T) { + client := newSecuritySecretTestClient(t) + created, err := client.SecuritySecret.Create(). + SetKey("retry_success_key"). + SetValue("retry-success-jwt-secret-value-32!!"). + Save(context.Background()) + require.NoError(t, err) + + got, err := querySecuritySecretWithRetry(context.Background(), client, "retry_success_key") + require.NoError(t, err) + require.Equal(t, created.ID, got.ID) + require.Equal(t, "retry-success-jwt-secret-value-32!!", got.Value) +} + +func TestQuerySecuritySecretWithRetryExhausted(t *testing.T) { + client := newSecuritySecretTestClient(t) + + _, err := querySecuritySecretWithRetry(context.Background(), client, "retry_missing_key") + require.Error(t, err) + require.True(t, isSecretNotFoundError(err)) +} + +func TestQuerySecuritySecretWithRetryContextCanceled(t *testing.T) { + client := newSecuritySecretTestClient(t) + ctx, cancel := context.WithTimeout(context.Background(), securitySecretReadRetryWait/2) + defer cancel() + + _, err := querySecuritySecretWithRetry(ctx, client, "retry_ctx_cancel_key") + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestQuerySecuritySecretWithRetryNonNotFoundError(t *testing.T) { + client := newSecuritySecretTestClient(t) + require.NoError(t, client.Close()) + + _, err := querySecuritySecretWithRetry(context.Background(), client, "retry_closed_client_key") + require.Error(t, err) + require.False(t, isSecretNotFoundError(err)) +} + +func TestSecretNotFoundHelpers(t *testing.T) { + require.False(t, isSecretNotFoundError(nil)) + require.False(t, isSQLNoRowsError(nil)) + + require.True(t, isSQLNoRowsError(sql.ErrNoRows)) + require.True(t, isSQLNoRowsError(fmt.Errorf("wrapped: %w", sql.ErrNoRows))) + require.True(t, isSQLNoRowsError(errors.New("sql: no rows in result set"))) + + require.True(t, isSecretNotFoundError(sql.ErrNoRows)) + require.True(t, isSecretNotFoundError(errors.New("sql: no rows in result set"))) + require.False(t, isSecretNotFoundError(errors.New("some other error"))) +} + +func TestGenerateHexSecretReadError(t *testing.T) { + originalRead := readRandomBytes + readRandomBytes = func([]byte) (int, error) { + return 0, errors.New("read random failed") + } + t.Cleanup(func() { + readRandomBytes = originalRead + }) + + _, err := generateHexSecret(32) + require.Error(t, err) + require.Contains(t, err.Error(), "read random failed") +} + +func TestGenerateHexSecretLengths(t *testing.T) { + v1, err := generateHexSecret(0) + require.NoError(t, err) + require.Len(t, v1, 64) + _, err = hex.DecodeString(v1) + require.NoError(t, err) + + v2, err := generateHexSecret(16) + require.NoError(t, err) + require.Len(t, v2, 32) + _, err = hex.DecodeString(v2) + require.NoError(t, err) + + require.NotEqual(t, v1, v2) +} diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go new file mode 100644 index 00000000..ad2ae638 --- /dev/null +++ b/backend/internal/repository/sora_account_repo.go @@ -0,0 +1,98 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// soraAccountRepository 实现 service.SoraAccountRepository 接口。 +// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。 +// +// 设计说明: +// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理 +// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义 +// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除 +type soraAccountRepository struct { + sql *sql.DB +} + +// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例 +func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository { + return &soraAccountRepository{sql: sqlDB} +} + +// Upsert 创建或更新 Sora 账号扩展信息 +// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert +func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error { + accessToken, accessOK := updates["access_token"].(string) + refreshToken, refreshOK := updates["refresh_token"].(string) + sessionToken, sessionOK := updates["session_token"].(string) + + if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" { + if !sessionOK { + return errors.New("缺少 access_token/refresh_token,且未提供可更新字段") + } + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_accounts + SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END, + updated_at = NOW() + WHERE account_id = $1 + `, accountID, sessionToken) + if err != nil { + return err + } + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return errors.New("sora_accounts 记录不存在,无法仅更新 session_token") + } + return nil + } + + _, err := r.sql.ExecContext(ctx, ` + INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at) + VALUES ($1, $2, $3, $4, NOW(), NOW()) + ON CONFLICT (account_id) DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END, + updated_at = NOW() + `, accountID, accessToken, refreshToken, sessionToken) + return err +} + +// GetByAccountID 根据账号 ID 获取 Sora 扩展信息 +func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) { + rows, err := r.sql.QueryContext(ctx, ` + SELECT account_id, access_token, refresh_token, COALESCE(session_token, '') + FROM sora_accounts + WHERE account_id = $1 + `, accountID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, nil // 记录不存在 + } + + var sa service.SoraAccount + if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil { + return nil, err + } + return &sa, nil +} + +// Delete 删除 Sora 账号扩展信息 +func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error { + _, err := r.sql.ExecContext(ctx, ` + DELETE FROM sora_accounts WHERE account_id = $1 + `, accountID) + return err +} diff --git a/backend/internal/repository/sora_generation_repo.go b/backend/internal/repository/sora_generation_repo.go new file mode 100644 index 00000000..aaf3cb2f --- /dev/null +++ b/backend/internal/repository/sora_generation_repo.go @@ -0,0 +1,419 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。 +// 使用原生 SQL 操作 sora_generations 表。 +type soraGenerationRepository struct { + sql *sql.DB +} + +// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。 +func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository { + return &soraGenerationRepository{sql: sqlDB} +} + +func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error { + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + + err := r.sql.QueryRowContext(ctx, ` + INSERT INTO sora_generations ( + user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + RETURNING id, created_at + `, + gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, + gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, + ).Scan(&gen.ID, &gen.CreatedAt) + return err +} + +// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。 +func (r *soraGenerationRepository) CreatePendingWithLimit( + ctx context.Context, + gen *service.SoraGeneration, + activeStatuses []string, + maxActive int64, +) error { + if gen == nil { + return fmt.Errorf("generation is nil") + } + if maxActive <= 0 { + return r.Create(ctx, gen) + } + if len(activeStatuses) == 0 { + activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating} + } + + tx, err := r.sql.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + // 使用用户级 advisory lock 串行化并发创建,避免超限竞态。 + if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil { + return err + } + + placeholders := make([]string, len(activeStatuses)) + args := make([]any, 0, 1+len(activeStatuses)) + args = append(args, gen.UserID) + for i, s := range activeStatuses { + placeholders[i] = fmt.Sprintf("$%d", i+2) + args = append(args, s) + } + countQuery := fmt.Sprintf( + `SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`, + strings.Join(placeholders, ","), + ) + var activeCount int64 + if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil { + return err + } + if activeCount >= maxActive { + return service.ErrSoraGenerationConcurrencyLimit + } + + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + if err := tx.QueryRowContext(ctx, ` + INSERT INTO sora_generations ( + user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + RETURNING id, created_at + `, + gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, + gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, + ).Scan(&gen.ID, &gen.CreatedAt); err != nil { + return err + } + + return tx.Commit() +} + +func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) { + gen := &service.SoraGeneration{} + var mediaURLsJSON, s3KeysJSON []byte + var completedAt sql.NullTime + var apiKeyID sql.NullInt64 + + err := r.sql.QueryRowContext(ctx, ` + SELECT id, user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message, + created_at, completed_at + FROM sora_generations WHERE id = $1 + `, id).Scan( + &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, + &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, + &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, + &gen.CreatedAt, &completedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("生成记录不存在") + } + return nil, err + } + + if apiKeyID.Valid { + gen.APIKeyID = &apiKeyID.Int64 + } + if completedAt.Valid { + gen.CompletedAt = &completedAt.Time + } + _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) + _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) + return gen, nil +} + +func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error { + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + + var completedAt *time.Time + if gen.CompletedAt != nil { + completedAt = gen.CompletedAt + } + + _, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations SET + status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5, + storage_type = $6, s3_object_keys = $7, upstream_task_id = $8, + error_message = $9, completed_at = $10 + WHERE id = $1 + `, + gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, + gen.ErrorMessage, completedAt, + ) + return err +} + +// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。 +func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, upstream_task_id = $3 + WHERE id = $1 AND status = $4 + `, + id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。 +func (r *soraGenerationRepository) UpdateCompletedIfActive( + ctx context.Context, + id int64, + mediaURL string, + mediaURLs []string, + storageType string, + s3Keys []string, + fileSizeBytes int64, + completedAt time.Time, +) (bool, error) { + mediaURLsJSON, _ := json.Marshal(mediaURLs) + s3KeysJSON, _ := json.Marshal(s3Keys) + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, + media_url = $3, + media_urls = $4, + file_size_bytes = $5, + storage_type = $6, + s3_object_keys = $7, + error_message = '', + completed_at = $8 + WHERE id = $1 AND status IN ($9, $10) + `, + id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes, + storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。 +func (r *soraGenerationRepository) UpdateFailedIfActive( + ctx context.Context, + id int64, + errMsg string, + completedAt time.Time, +) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, + error_message = $3, + completed_at = $4 + WHERE id = $1 AND status IN ($5, $6) + `, + id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。 +func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, completed_at = $3 + WHERE id = $1 AND status IN ($4, $5) + `, + id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。 +func (r *soraGenerationRepository) UpdateStorageIfCompleted( + ctx context.Context, + id int64, + mediaURL string, + mediaURLs []string, + storageType string, + s3Keys []string, + fileSizeBytes int64, +) (bool, error) { + mediaURLsJSON, _ := json.Marshal(mediaURLs) + s3KeysJSON, _ := json.Marshal(s3Keys) + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET media_url = $2, + media_urls = $3, + file_size_bytes = $4, + storage_type = $5, + s3_object_keys = $6 + WHERE id = $1 AND status = $7 + `, + id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error { + _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id) + return err +} + +func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { + // 构建 WHERE 条件 + conditions := []string{"user_id = $1"} + args := []any{params.UserID} + argIdx := 2 + + if params.Status != "" { + // 支持逗号分隔的多状态 + statuses := strings.Split(params.Status, ",") + placeholders := make([]string, len(statuses)) + for i, s := range statuses { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, strings.TrimSpace(s)) + argIdx++ + } + conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ","))) + } + if params.StorageType != "" { + storageTypes := strings.Split(params.StorageType, ",") + placeholders := make([]string, len(storageTypes)) + for i, s := range storageTypes { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, strings.TrimSpace(s)) + argIdx++ + } + conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ","))) + } + if params.MediaType != "" { + conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx)) + args = append(args, params.MediaType) + argIdx++ + } + + whereClause := "WHERE " + strings.Join(conditions, " AND ") + + // 计数 + var total int64 + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause) + if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, 0, err + } + + // 分页查询 + offset := (params.Page - 1) * params.PageSize + listQuery := fmt.Sprintf(` + SELECT id, user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message, + created_at, completed_at + FROM sora_generations %s + ORDER BY created_at DESC + LIMIT $%d OFFSET $%d + `, whereClause, argIdx, argIdx+1) + args = append(args, params.PageSize, offset) + + rows, err := r.sql.QueryContext(ctx, listQuery, args...) + if err != nil { + return nil, 0, err + } + defer func() { + _ = rows.Close() + }() + + var results []*service.SoraGeneration + for rows.Next() { + gen := &service.SoraGeneration{} + var mediaURLsJSON, s3KeysJSON []byte + var completedAt sql.NullTime + var apiKeyID sql.NullInt64 + + if err := rows.Scan( + &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, + &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, + &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, + &gen.CreatedAt, &completedAt, + ); err != nil { + return nil, 0, err + } + + if apiKeyID.Valid { + gen.APIKeyID = &apiKeyID.Int64 + } + if completedAt.Valid { + gen.CompletedAt = &completedAt.Time + } + _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) + _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) + results = append(results, gen) + } + + return results, total, rows.Err() +} + +func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) { + if len(statuses) == 0 { + return 0, nil + } + + placeholders := make([]string, len(statuses)) + args := []any{userID} + for i, s := range statuses { + placeholders[i] = fmt.Sprintf("$%d", i+2) + args = append(args, s) + } + + var count int64 + query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ",")) + err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count) + return count, err +} diff --git a/backend/internal/repository/usage_cleanup_repo.go b/backend/internal/repository/usage_cleanup_repo.go index 9c021357..1a25696e 100644 --- a/backend/internal/repository/usage_cleanup_repo.go +++ b/backend/internal/repository/usage_cleanup_repo.go @@ -362,7 +362,12 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) idx++ } } - if filters.Stream != nil { + if filters.RequestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(idx, *filters.RequestType) + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + idx += len(conditionArgs) + } else if filters.Stream != nil { conditions = append(conditions, fmt.Sprintf("stream = $%d", idx)) args = append(args, *filters.Stream) idx++ diff --git a/backend/internal/repository/usage_cleanup_repo_test.go b/backend/internal/repository/usage_cleanup_repo_test.go index 0ca30ec7..1ac7cca5 100644 --- a/backend/internal/repository/usage_cleanup_repo_test.go +++ b/backend/internal/repository/usage_cleanup_repo_test.go @@ -466,6 +466,38 @@ func TestBuildUsageCleanupWhere(t *testing.T) { require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args) } +func TestBuildUsageCleanupWhereRequestTypePriority(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeWSV2) + stream := false + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + RequestType: &requestType, + Stream: &stream, + }) + + require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", where) + require.Equal(t, []any{start, end, requestType}, args) +} + +func TestBuildUsageCleanupWhereRequestTypeLegacyFallback(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeStream) + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + RequestType: &requestType, + }) + + require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", where) + require.Equal(t, []any{start, end, requestType}, args) +} + func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) { start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) end := start.Add(24 * time.Hour) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 2db1764f..d30cc7dd 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,23 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" + +// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL +var dateFormatWhitelist = map[string]string{ + "hour": "YYYY-MM-DD HH24:00", + "day": "YYYY-MM-DD", + "week": "IYYY-IW", + "month": "YYYY-MM", +} + +// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值 +func safeDateFormat(granularity string) string { + if f, ok := dateFormatWhitelist[granularity]; ok { + return f + } + return "YYYY-MM-DD" +} type usageLogRepository struct { client *dbent.Client @@ -82,6 +98,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) log.RequestID = requestID rateMultiplier := log.RateMultiplier + log.SyncRequestTypeAndLegacyFields() + requestType := int16(log.RequestType) query := ` INSERT INTO usage_logs ( @@ -107,26 +125,30 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) rate_multiplier, account_rate_multiplier, billing_type, + request_type, stream, + openai_ws_mode, duration_ms, first_token_ms, user_agent, - ip_address, - image_count, - image_size, - reasoning_effort, - created_at - ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31 - ) - ON CONFLICT (request_id, api_key_id) DO NOTHING - RETURNING id, created_at - ` + ip_address, + image_count, + image_size, + media_type, + reasoning_effort, + cache_ttl_overridden, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, + $8, $9, $10, $11, + $12, $13, + $14, $15, $16, $17, $18, $19, + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35 + ) + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING id, created_at + ` groupID := nullInt64(log.GroupID) subscriptionID := nullInt64(log.SubscriptionID) @@ -135,6 +157,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) userAgent := nullString(log.UserAgent) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) + mediaType := nullString(log.MediaType) reasoningEffort := nullString(log.ReasoningEffort) var requestIDArg any @@ -165,14 +188,18 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) rateMultiplier, log.AccountRateMultiplier, log.BillingType, + requestType, log.Stream, + log.OpenAIWSMode, duration, firstToken, userAgent, ipAddress, log.ImageCount, imageSize, + mediaType, reasoningEffort, + log.CacheTTLOverridden, createdAt, } if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { @@ -471,25 +498,46 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte } func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error { - totalStatsQuery := ` + todayEnd := todayUTC.Add(24 * time.Hour) + combinedStatsQuery := ` + WITH scoped AS ( + SELECT + created_at, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + total_cost, + actual_cost, + COALESCE(duration_ms, 0) AS duration_ms + FROM usage_logs + WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz) + AND created_at < GREATEST($2::timestamptz, $4::timestamptz) + ) SELECT - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as total_cost, - COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 + COUNT(*) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz) AS total_requests, + COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_input_tokens, + COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_output_tokens, + COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens, + COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost, + COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms, + COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests, + COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens, + COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_output_tokens, + COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens, + COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost + FROM scoped ` var totalDurationMs int64 if err := scanSingleRow( ctx, r.sql, - totalStatsQuery, - []any{startUTC, endUTC}, + combinedStatsQuery, + []any{startUTC, endUTC, todayUTC, todayEnd}, &stats.TotalRequests, &stats.TotalInputTokens, &stats.TotalOutputTokens, @@ -498,32 +546,6 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co &stats.TotalCost, &stats.TotalActualCost, &totalDurationMs, - ); err != nil { - return err - } - stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens - if stats.TotalRequests > 0 { - stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) - } - - todayEnd := todayUTC.Add(24 * time.Hour) - todayStatsQuery := ` - SELECT - COUNT(*) as today_requests, - COALESCE(SUM(input_tokens), 0) as today_input_tokens, - COALESCE(SUM(output_tokens), 0) as today_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as today_cost, - COALESCE(SUM(actual_cost), 0) as today_actual_cost - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 - ` - if err := scanSingleRow( - ctx, - r.sql, - todayStatsQuery, - []any{todayUTC, todayEnd}, &stats.TodayRequests, &stats.TodayInputTokens, &stats.TodayOutputTokens, @@ -534,25 +556,28 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co ); err != nil { return err } - stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens - - activeUsersQuery := ` - SELECT COUNT(DISTINCT user_id) as active_users - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 - ` - if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil { - return err + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens + if stats.TotalRequests > 0 { + stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) } + stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens + hourStart := now.UTC().Truncate(time.Hour) hourEnd := hourStart.Add(time.Hour) - hourlyActiveQuery := ` - SELECT COUNT(DISTINCT user_id) as active_users - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 + activeUsersQuery := ` + WITH scoped AS ( + SELECT user_id, created_at + FROM usage_logs + WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz) + AND created_at < GREATEST($2::timestamptz, $4::timestamptz) + ) + SELECT + COUNT(DISTINCT CASE WHEN created_at >= $1::timestamptz AND created_at < $2::timestamptz THEN user_id END) AS active_users, + COUNT(DISTINCT CASE WHEN created_at >= $3::timestamptz AND created_at < $4::timestamptz THEN user_id END) AS hourly_active_users + FROM scoped ` - if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil { + if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd, hourStart, hourEnd}, &stats.ActiveUsers, &stats.HourlyActiveUsers); err != nil { return err } @@ -564,7 +589,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, } func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime) return logs, nil, err } @@ -810,19 +835,19 @@ func resolveUsageStatsTimezone() string { } func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) return logs, nil, err } func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime) return logs, nil, err } func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) return logs, nil, err } @@ -894,6 +919,114 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI return stats, nil } +// GetAccountWindowStatsBatch 批量获取同一窗口起点下多个账号的统计数据。 +// 返回 map[accountID]*AccountStats,未命中的账号会返回零值统计,便于上层直接复用。 +func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) { + result := make(map[int64]*usagestats.AccountStats, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + query := ` + SELECT + account_id, + COUNT(*) as requests, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, + COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, + COALESCE(SUM(total_cost), 0) as standard_cost, + COALESCE(SUM(actual_cost), 0) as user_cost + FROM usage_logs + WHERE account_id = ANY($1) AND created_at >= $2 + GROUP BY account_id + ` + rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var accountID int64 + stats := &usagestats.AccountStats{} + if err := rows.Scan( + &accountID, + &stats.Requests, + &stats.Tokens, + &stats.Cost, + &stats.StandardCost, + &stats.UserCost, + ); err != nil { + return nil, err + } + result[accountID] = stats + } + if err := rows.Err(); err != nil { + return nil, err + } + + for _, accountID := range accountIDs { + if _, ok := result[accountID]; !ok { + result[accountID] = &usagestats.AccountStats{} + } + } + return result, nil +} + +// GetGeminiUsageTotalsBatch 批量聚合 Gemini 账号在窗口内的 Pro/Flash 请求与用量。 +// 模型分类规则与 service.geminiModelClassFromName 一致:model 包含 flash/lite 视为 flash,其余视为 pro。 +func (r *usageLogRepository) GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]service.GeminiUsageTotals, error) { + result := make(map[int64]service.GeminiUsageTotals, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + query := ` + SELECT + account_id, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 1 ELSE 0 END), 0) AS flash_requests, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE 1 END), 0) AS pro_requests, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) ELSE 0 END), 0) AS flash_tokens, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) END), 0) AS pro_tokens, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN actual_cost ELSE 0 END), 0) AS flash_cost, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE actual_cost END), 0) AS pro_cost + FROM usage_logs + WHERE account_id = ANY($1) AND created_at >= $2 AND created_at < $3 + GROUP BY account_id + ` + rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime, endTime) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var accountID int64 + var totals service.GeminiUsageTotals + if err := rows.Scan( + &accountID, + &totals.FlashRequests, + &totals.ProRequests, + &totals.FlashTokens, + &totals.ProTokens, + &totals.FlashCost, + &totals.ProCost, + ); err != nil { + return nil, err + } + result[accountID] = totals + } + if err := rows.Err(); err != nil { + return nil, err + } + + for _, accountID := range accountIDs { + if _, ok := result[accountID]; !ok { + result[accountID] = service.GeminiUsageTotals{} + } + } + return result, nil +} + // TrendDataPoint represents a single point in trend data type TrendDataPoint = usagestats.TrendDataPoint @@ -908,10 +1041,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint // GetAPIKeyUsageTrend returns usage trend data grouped by API key and date func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` WITH top_keys AS ( @@ -966,10 +1096,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, // GetUserUsageTrend returns usage trend data grouped by user and date func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` WITH top_users AS ( @@ -1228,10 +1355,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey // GetUserUsageTrendByUserID 获取指定用户的使用趋势 func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` SELECT @@ -1334,10 +1458,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) args = append(args, filters.Model) } - if filters.Stream != nil { - conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) - args = append(args, *filters.Stream) - } + conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) @@ -1369,13 +1490,22 @@ type UsageStats = usagestats.UsageStats // BatchUserUsageStats represents usage stats for a single user type BatchUserUsageStats = usagestats.BatchUserUsageStats -// GetBatchUserUsageStats gets today and total actual_cost for multiple users -func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { +// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range. +// If startTime is zero, defaults to 30 days ago. +func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) { result := make(map[int64]*BatchUserUsageStats) if len(userIDs) == 0 { return result, nil } + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + for _, id := range userIDs { result[id] = &BatchUserUsageStats{UserID: id} } @@ -1383,10 +1513,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs query := ` SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost FROM usage_logs - WHERE user_id = ANY($1) + WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3 GROUP BY user_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs)) + rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime) if err != nil { return nil, err } @@ -1443,13 +1573,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs // BatchAPIKeyUsageStats represents usage stats for a single API key type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats -// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys -func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) { +// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range. +// If startTime is zero, defaults to 30 days ago. +func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) { result := make(map[int64]*BatchAPIKeyUsageStats) if len(apiKeyIDs) == 0 { return result, nil } + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + for _, id := range apiKeyIDs { result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} } @@ -1457,10 +1596,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe query := ` SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost FROM usage_logs - WHERE api_key_id = ANY($1) + WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3 GROUP BY api_key_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs)) + rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime) if err != nil { return nil, err } @@ -1515,11 +1654,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe } // GetUsageTrendWithFilters returns usage trend data with optional filters -func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } +func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` SELECT @@ -1556,10 +1692,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND model = $%d", len(args)+1) args = append(args, model) } - if stream != nil { - query += fmt.Sprintf(" AND stream = $%d", len(args)+1) - args = append(args, *stream) - } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) @@ -1587,7 +1720,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start } // GetModelStatsWithFilters returns model statistics with optional filters -func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) { +func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { @@ -1624,10 +1757,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if stream != nil { - query += fmt.Sprintf(" AND stream = $%d", len(args)+1) - args = append(args, *stream) - } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) @@ -1654,6 +1784,77 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start return results, nil } +// GetGroupStatsWithFilters returns group usage statistics with optional filters +func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []usagestats.GroupStat, err error) { + query := ` + SELECT + COALESCE(ul.group_id, 0) as group_id, + COALESCE(g.name, '') as group_name, + COUNT(*) as requests, + COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, + COALESCE(SUM(ul.total_cost), 0) as cost, + COALESCE(SUM(ul.actual_cost), 0) as actual_cost + FROM usage_logs ul + LEFT JOIN groups g ON g.id = ul.group_id + WHERE ul.created_at >= $1 AND ul.created_at < $2 + ` + + args := []any{startTime, endTime} + if userID > 0 { + query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1) + args = append(args, userID) + } + if apiKeyID > 0 { + query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1) + args = append(args, apiKeyID) + } + if accountID > 0 { + query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1) + args = append(args, accountID) + } + if groupID > 0 { + query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1) + args = append(args, groupID) + } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) + if billingType != nil { + query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1) + args = append(args, int16(*billingType)) + } + query += " GROUP BY ul.group_id, g.name ORDER BY total_tokens DESC" + + rows, err := r.sql.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results = make([]usagestats.GroupStat, 0) + for rows.Next() { + var row usagestats.GroupStat + if err := rows.Scan( + &row.GroupID, + &row.GroupName, + &row.Requests, + &row.TotalTokens, + &row.Cost, + &row.ActualCost, + ); err != nil { + return nil, err + } + results = append(results, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} + // GetGlobalStats gets usage statistics for all users within a time range func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) { query := ` @@ -1714,10 +1915,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) args = append(args, filters.Model) } - if filters.Stream != nil { - conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) - args = append(args, *filters.Stream) - } + conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) @@ -1937,7 +2135,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID } } - models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil) + models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil, nil) if err != nil { models = []ModelStat{} } @@ -2187,14 +2385,18 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e rateMultiplier float64 accountRateMultiplier sql.NullFloat64 billingType int16 + requestTypeRaw int16 stream bool + openaiWSMode bool durationMs sql.NullInt64 firstTokenMs sql.NullInt64 userAgent sql.NullString ipAddress sql.NullString imageCount int imageSize sql.NullString + mediaType sql.NullString reasoningEffort sql.NullString + cacheTTLOverridden bool createdAt time.Time ) @@ -2222,14 +2424,18 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &rateMultiplier, &accountRateMultiplier, &billingType, + &requestTypeRaw, &stream, + &openaiWSMode, &durationMs, &firstTokenMs, &userAgent, &ipAddress, &imageCount, &imageSize, + &mediaType, &reasoningEffort, + &cacheTTLOverridden, &createdAt, ); err != nil { return nil, err @@ -2256,10 +2462,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e RateMultiplier: rateMultiplier, AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier), BillingType: int8(billingType), - Stream: stream, + RequestType: service.RequestTypeFromInt16(requestTypeRaw), ImageCount: imageCount, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: createdAt, } + // 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。 + log.Stream = stream + log.OpenAIWSMode = openaiWSMode + log.RequestType = log.EffectiveRequestType() + log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode) if requestID.Valid { log.RequestID = requestID.String @@ -2289,6 +2501,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if imageSize.Valid { log.ImageSize = &imageSize.String } + if mediaType.Valid { + log.MediaType = &mediaType.String + } if reasoningEffort.Valid { log.ReasoningEffort = &reasoningEffort.String } @@ -2350,6 +2565,50 @@ func buildWhere(conditions []string) string { return "WHERE " + strings.Join(conditions, " AND ") } +func appendRequestTypeOrStreamWhereCondition(conditions []string, args []any, requestType *int16, stream *bool) ([]string, []any) { + if requestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType) + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + return conditions, args + } + if stream != nil { + conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) + args = append(args, *stream) + } + return conditions, args +} + +func appendRequestTypeOrStreamQueryFilter(query string, args []any, requestType *int16, stream *bool) (string, []any) { + if requestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType) + query += " AND " + condition + args = append(args, conditionArgs...) + return query, args + } + if stream != nil { + query += fmt.Sprintf(" AND stream = $%d", len(args)+1) + args = append(args, *stream) + } + return query, args +} + +// buildRequestTypeFilterCondition 在 request_type 过滤时兼容 legacy 字段,避免历史数据漏查。 +func buildRequestTypeFilterCondition(startArgIndex int, requestType int16) (string, []any) { + normalized := service.RequestTypeFromInt16(requestType) + requestTypeArg := int16(normalized) + switch normalized { + case service.RequestTypeSync: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = FALSE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + case service.RequestTypeStream: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = TRUE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + case service.RequestTypeWSV2: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND openai_ws_mode = TRUE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + default: + return fmt.Sprintf("request_type = $%d", startArgIndex), []any{requestTypeArg} + } +} + func nullInt64(v *int64) sql.NullInt64 { if v == nil { return sql.NullInt64{} diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index eb220f22..4d50f7de 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -130,6 +130,62 @@ func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() { s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001) } +func (s *UsageLogRepoSuite) TestGetByID_ReturnsOpenAIWSMode() { + user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-ws@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-ws", Name: "k"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-ws"}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.New().String(), + Model: "gpt-5.3-codex", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1.0, + ActualCost: 1.0, + OpenAIWSMode: true, + CreatedAt: timezone.Today().Add(3 * time.Hour), + } + _, err := s.repo.Create(s.ctx, log) + s.Require().NoError(err) + + got, err := s.repo.GetByID(s.ctx, log.ID) + s.Require().NoError(err) + s.Require().True(got.OpenAIWSMode) +} + +func (s *UsageLogRepoSuite) TestGetByID_ReturnsRequestTypeAndLegacyFallback() { + user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-request-type@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-request-type", Name: "k"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-request-type"}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.New().String(), + Model: "gpt-5.3-codex", + RequestType: service.RequestTypeWSV2, + Stream: true, + OpenAIWSMode: false, + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1.0, + ActualCost: 1.0, + CreatedAt: timezone.Today().Add(4 * time.Hour), + } + _, err := s.repo.Create(s.ctx, log) + s.Require().NoError(err) + + got, err := s.repo.GetByID(s.ctx, log.ID) + s.Require().NoError(err) + s.Require().Equal(service.RequestTypeWSV2, got.RequestType) + s.Require().True(got.Stream) + s.Require().True(got.OpenAIWSMode) +} + // --- Delete --- func (s *UsageLogRepoSuite) TestDelete() { @@ -648,7 +704,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now()) - stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}) + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{}) s.Require().NoError(err, "GetBatchUserUsageStats") s.Require().Len(stats, 2) s.Require().NotNil(stats[user1.ID]) @@ -656,7 +712,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { } func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { - stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}) + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{}) s.Require().NoError(err) s.Require().Empty(stats) } @@ -672,13 +728,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}) + stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{}) s.Require().NoError(err, "GetBatchAPIKeyUsageStats") s.Require().Len(stats, 2) } func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}) + stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{}) s.Require().NoError(err) s.Require().Empty(stats) } @@ -944,17 +1000,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { endTime := base.Add(48 * time.Hour) // Test with user filter - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil) + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters user filter") s.Require().Len(trend, 2) // Test with apiKey filter - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter") s.Require().Len(trend, 2) // Test with both filters - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters both filters") s.Require().Len(trend, 2) } @@ -971,7 +1027,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { startTime := base.Add(-1 * time.Hour) endTime := base.Add(3 * time.Hour) - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil) + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters hourly") s.Require().Len(trend, 2) } @@ -1017,17 +1073,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { endTime := base.Add(2 * time.Hour) // Test with user filter - stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil) + stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters user filter") s.Require().Len(stats, 2) // Test with apiKey filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil) + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter") s.Require().Len(stats, 2) // Test with account filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil) + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters account filter") s.Require().Len(stats, 2) } diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go new file mode 100644 index 00000000..95cf2a2d --- /dev/null +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -0,0 +1,327 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + log := &service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-1", + Model: "gpt-5", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1, + ActualCost: 1, + BillingType: service.BillingTypeBalance, + RequestType: service.RequestTypeWSV2, + Stream: false, + OpenAIWSMode: false, + CreatedAt: createdAt, + } + + mock.ExpectQuery("INSERT INTO usage_logs"). + WithArgs( + log.UserID, + log.APIKeyID, + log.AccountID, + log.RequestID, + log.Model, + sqlmock.AnyArg(), // group_id + sqlmock.AnyArg(), // subscription_id + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + log.RateMultiplier, + log.AccountRateMultiplier, + log.BillingType, + int16(service.RequestTypeWSV2), + true, + true, + sqlmock.AnyArg(), // duration_ms + sqlmock.AnyArg(), // first_token_ms + sqlmock.AnyArg(), // user_agent + sqlmock.AnyArg(), // ip_address + log.ImageCount, + sqlmock.AnyArg(), // image_size + sqlmock.AnyArg(), // media_type + sqlmock.AnyArg(), // reasoning_effort + log.CacheTTLOverridden, + createdAt, + ). + WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) + + inserted, err := repo.Create(context.Background(), log) + require.NoError(t, err) + require.True(t, inserted) + require.Equal(t, int64(99), log.ID) + require.Equal(t, service.RequestTypeWSV2, log.RequestType) + require.True(t, log.Stream) + require.True(t, log.OpenAIWSMode) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + requestType := int16(service.RequestTypeWSV2) + stream := false + filters := usagestats.UsageLogFilters{ + RequestType: &requestType, + Stream: &stream, + } + + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). + WithArgs(requestType). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0))) + mock.ExpectQuery("SELECT .* FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\) ORDER BY id DESC LIMIT \\$2 OFFSET \\$3"). + WithArgs(requestType, 20, 0). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + logs, page, err := repo.ListWithFilters(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}, filters) + require.NoError(t, err) + require.Empty(t, logs) + require.NotNil(t, page) + require.Equal(t, int64(0), page.Total) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeStream) + stream := true + + mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)"). + WithArgs(start, end, requestType). + WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"})) + + trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil) + require.NoError(t, err) + require.Empty(t, trend) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeWSV2) + stream := false + + mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). + WithArgs(start, end, requestType). + WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"})) + + stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil) + require.NoError(t, err) + require.Empty(t, stats) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + requestType := int16(service.RequestTypeSync) + stream := true + filters := usagestats.UsageLogFilters{ + RequestType: &requestType, + Stream: &stream, + } + + mock.ExpectQuery("FROM usage_logs\\s+WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE\\)\\)"). + WithArgs(requestType). + WillReturnRows(sqlmock.NewRows([]string{ + "total_requests", + "total_input_tokens", + "total_output_tokens", + "total_cache_tokens", + "total_cost", + "total_actual_cost", + "total_account_cost", + "avg_duration_ms", + }).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0)) + + stats, err := repo.GetStatsWithFilters(context.Background(), filters) + require.NoError(t, err) + require.Equal(t, int64(1), stats.TotalRequests) + require.Equal(t, int64(9), stats.TotalTokens) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) { + tests := []struct { + name string + request int16 + wantWhere string + wantArg int16 + }{ + { + name: "sync_with_legacy_fallback", + request: int16(service.RequestTypeSync), + wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE))", + wantArg: int16(service.RequestTypeSync), + }, + { + name: "stream_with_legacy_fallback", + request: int16(service.RequestTypeStream), + wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", + wantArg: int16(service.RequestTypeStream), + }, + { + name: "ws_v2_with_legacy_fallback", + request: int16(service.RequestTypeWSV2), + wantWhere: "(request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", + wantArg: int16(service.RequestTypeWSV2), + }, + { + name: "invalid_request_type_normalized_to_unknown", + request: int16(99), + wantWhere: "request_type = $3", + wantArg: int16(service.RequestTypeUnknown), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + where, args := buildRequestTypeFilterCondition(3, tt.request) + require.Equal(t, tt.wantWhere, where) + require.Equal(t, []any{tt.wantArg}, args) + }) + } +} + +type usageLogScannerStub struct { + values []any +} + +func (s usageLogScannerStub) Scan(dest ...any) error { + if len(dest) != len(s.values) { + return fmt.Errorf("scan arg count mismatch: got %d want %d", len(dest), len(s.values)) + } + for i := range dest { + dv := reflect.ValueOf(dest[i]) + if dv.Kind() != reflect.Ptr { + return fmt.Errorf("dest[%d] is not pointer", i) + } + dv.Elem().Set(reflect.ValueOf(s.values[i])) + } + return nil +} + +func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { + t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(1), // id + int64(10), // user_id + int64(20), // api_key_id + int64(30), // account_id + sql.NullString{Valid: true, String: "req-1"}, + "gpt-5", // model + sql.NullInt64{}, // group_id + sql.NullInt64{}, // subscription_id + 1, // input_tokens + 2, // output_tokens + 3, // cache_creation_tokens + 4, // cache_read_tokens + 5, // cache_creation_5m_tokens + 6, // cache_creation_1h_tokens + 0.1, // input_cost + 0.2, // output_cost + 0.3, // cache_creation_cost + 0.4, // cache_read_cost + 1.0, // total_cost + 0.9, // actual_cost + 1.0, // rate_multiplier + sql.NullFloat64{}, // account_rate_multiplier + int16(service.BillingTypeBalance), + int16(service.RequestTypeWSV2), + false, // legacy stream + false, // legacy openai ws + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 0, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + false, + now, + }}) + require.NoError(t, err) + require.Equal(t, service.RequestTypeWSV2, log.RequestType) + require.True(t, log.Stream) + require.True(t, log.OpenAIWSMode) + }) + + t.Run("request_type_unknown_falls_back_to_legacy", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(2), + int64(11), + int64(21), + int64(31), + sql.NullString{Valid: true, String: "req-2"}, + "gpt-5", + sql.NullInt64{}, + sql.NullInt64{}, + 1, 2, 3, 4, 5, 6, + 0.1, 0.2, 0.3, 0.4, 1.0, 0.9, + 1.0, + sql.NullFloat64{}, + int16(service.BillingTypeBalance), + int16(service.RequestTypeUnknown), + true, + false, + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 0, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + false, + now, + }}) + require.NoError(t, err) + require.Equal(t, service.RequestTypeStream, log.RequestType) + require.True(t, log.Stream) + require.False(t, log.OpenAIWSMode) + }) +} diff --git a/backend/internal/repository/usage_log_repo_unit_test.go b/backend/internal/repository/usage_log_repo_unit_test.go new file mode 100644 index 00000000..d0e14ffd --- /dev/null +++ b/backend/internal/repository/usage_log_repo_unit_test.go @@ -0,0 +1,41 @@ +//go:build unit + +package repository + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSafeDateFormat(t *testing.T) { + tests := []struct { + name string + granularity string + expected string + }{ + // 合法值 + {"hour", "hour", "YYYY-MM-DD HH24:00"}, + {"day", "day", "YYYY-MM-DD"}, + {"week", "week", "IYYY-IW"}, + {"month", "month", "YYYY-MM"}, + + // 非法值回退到默认 + {"空字符串", "", "YYYY-MM-DD"}, + {"未知粒度 year", "year", "YYYY-MM-DD"}, + {"未知粒度 minute", "minute", "YYYY-MM-DD"}, + + // 恶意字符串 + {"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"}, + {"带引号", "day'", "YYYY-MM-DD"}, + {"带括号", "day)", "YYYY-MM-DD"}, + {"Unicode", "日", "YYYY-MM-DD"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := safeDateFormat(tc.granularity) + require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity) + }) + } +} diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go index eb65403b..e3b11096 100644 --- a/backend/internal/repository/user_group_rate_repo.go +++ b/backend/internal/repository/user_group_rate_repo.go @@ -6,6 +6,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" ) type userGroupRateRepository struct { @@ -41,6 +42,59 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) return result, nil } +// GetByUserIDs 批量获取多个用户的专属分组倍率。 +// 返回结构:map[userID]map[groupID]rate +func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) { + result := make(map[int64]map[int64]float64, len(userIDs)) + if len(userIDs) == 0 { + return result, nil + } + + uniqueIDs := make([]int64, 0, len(userIDs)) + seen := make(map[int64]struct{}, len(userIDs)) + for _, userID := range userIDs { + if userID <= 0 { + continue + } + if _, exists := seen[userID]; exists { + continue + } + seen[userID] = struct{}{} + uniqueIDs = append(uniqueIDs, userID) + result[userID] = make(map[int64]float64) + } + if len(uniqueIDs) == 0 { + return result, nil + } + + rows, err := r.sql.QueryContext(ctx, ` + SELECT user_id, group_id, rate_multiplier + FROM user_group_rate_multipliers + WHERE user_id = ANY($1) + `, pq.Array(uniqueIDs)) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var userID int64 + var groupID int64 + var rate float64 + if err := rows.Scan(&userID, &groupID, &rate); err != nil { + return nil, err + } + if _, ok := result[userID]; !ok { + result[userID] = make(map[int64]float64) + } + result[userID][groupID] = rate + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + // GetByUserAndGroup 获取用户在特定分组的专属倍率 func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` @@ -65,33 +119,43 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID // 分离需要删除和需要 upsert 的记录 var toDelete []int64 - toUpsert := make(map[int64]float64) + upsertGroupIDs := make([]int64, 0, len(rates)) + upsertRates := make([]float64, 0, len(rates)) for groupID, rate := range rates { if rate == nil { toDelete = append(toDelete, groupID) } else { - toUpsert[groupID] = *rate + upsertGroupIDs = append(upsertGroupIDs, groupID) + upsertRates = append(upsertRates, *rate) } } // 删除指定的记录 - for _, groupID := range toDelete { - _, err := r.sql.ExecContext(ctx, - `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`, - userID, groupID) - if err != nil { + if len(toDelete) > 0 { + if _, err := r.sql.ExecContext(ctx, + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`, + userID, pq.Array(toDelete)); err != nil { return err } } // Upsert 记录 now := time.Now() - for groupID, rate := range toUpsert { + if len(upsertGroupIDs) > 0 { _, err := r.sql.ExecContext(ctx, ` INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) - VALUES ($1, $2, $3, $4, $4) - ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4 - `, userID, groupID, rate, now) + SELECT + $1::bigint, + data.group_id, + data.rate_multiplier, + $2::timestamptz, + $2::timestamptz + FROM unnest($3::bigint[], $4::double precision[]) AS data(group_id, rate_multiplier) + ON CONFLICT (user_id, group_id) + DO UPDATE SET + rate_multiplier = EXCLUDED.rate_multiplier, + updated_at = EXCLUDED.updated_at + `, userID, now, pq.Array(upsertGroupIDs), pq.Array(upsertRates)) if err != nil { return err } diff --git a/backend/internal/repository/user_msg_queue_cache.go b/backend/internal/repository/user_msg_queue_cache.go new file mode 100644 index 00000000..bb3ee698 --- /dev/null +++ b/backend/internal/repository/user_msg_queue_cache.go @@ -0,0 +1,186 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +// Redis Key 模式(使用 hash tag 确保 Redis Cluster 下同一 accountID 的 key 落入同一 slot) +// 格式: umq:{accountID}:lock / umq:{accountID}:last +const ( + umqKeyPrefix = "umq:" + umqLockSuffix = ":lock" // STRING (requestID), PX lockTtlMs + umqLastSuffix = ":last" // STRING (毫秒时间戳), EX 60s +) + +// Lua 脚本:原子获取串行锁(SET NX PX + 重入安全) +var acquireLockScript = redis.NewScript(` +local cur = redis.call('GET', KEYS[1]) +if cur == ARGV[1] then + redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[2])) + return 1 +end +if cur ~= false then return 0 end +redis.call('SET', KEYS[1], ARGV[1], 'PX', tonumber(ARGV[2])) +return 1 +`) + +// Lua 脚本:原子释放锁 + 记录完成时间(使用 Redis TIME 避免时钟偏差) +var releaseLockScript = redis.NewScript(` +local cur = redis.call('GET', KEYS[1]) +if cur == ARGV[1] then + redis.call('DEL', KEYS[1]) + local t = redis.call('TIME') + local ms = tonumber(t[1])*1000 + math.floor(tonumber(t[2])/1000) + redis.call('SET', KEYS[2], ms, 'EX', 60) + return 1 +end +return 0 +`) + +// Lua 脚本:原子清理孤儿锁(仅在 PTTL == -1 时删除,避免 TOCTOU 竞态误删合法锁) +var forceReleaseLockScript = redis.NewScript(` +local pttl = redis.call('PTTL', KEYS[1]) +if pttl == -1 then + redis.call('DEL', KEYS[1]) + return 1 +end +return 0 +`) + +type userMsgQueueCache struct { + rdb *redis.Client +} + +// NewUserMsgQueueCache 创建用户消息队列缓存 +func NewUserMsgQueueCache(rdb *redis.Client) service.UserMsgQueueCache { + return &userMsgQueueCache{rdb: rdb} +} + +func umqLockKey(accountID int64) string { + // 格式: umq:{123}:lock — 花括号确保 Redis Cluster hash tag 生效 + return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLockSuffix +} + +func umqLastKey(accountID int64) string { + // 格式: umq:{123}:last — 与 lockKey 同一 hash slot + return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLastSuffix +} + +// umqScanPattern 用于 SCAN 扫描锁 key +func umqScanPattern() string { + return umqKeyPrefix + "{*}" + umqLockSuffix +} + +// AcquireLock 尝试获取账号级串行锁 +func (c *userMsgQueueCache) AcquireLock(ctx context.Context, accountID int64, requestID string, lockTtlMs int) (bool, error) { + key := umqLockKey(accountID) + result, err := acquireLockScript.Run(ctx, c.rdb, []string{key}, requestID, lockTtlMs).Int() + if err != nil { + return false, fmt.Errorf("umq acquire lock: %w", err) + } + return result == 1, nil +} + +// ReleaseLock 释放锁并记录完成时间 +func (c *userMsgQueueCache) ReleaseLock(ctx context.Context, accountID int64, requestID string) (bool, error) { + lockKey := umqLockKey(accountID) + lastKey := umqLastKey(accountID) + result, err := releaseLockScript.Run(ctx, c.rdb, []string{lockKey, lastKey}, requestID).Int() + if err != nil { + return false, fmt.Errorf("umq release lock: %w", err) + } + return result == 1, nil +} + +// GetLastCompletedMs 获取上次完成时间(毫秒时间戳) +func (c *userMsgQueueCache) GetLastCompletedMs(ctx context.Context, accountID int64) (int64, error) { + key := umqLastKey(accountID) + val, err := c.rdb.Get(ctx, key).Result() + if errors.Is(err, redis.Nil) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("umq get last completed: %w", err) + } + ms, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return 0, fmt.Errorf("umq parse last completed: %w", err) + } + return ms, nil +} + +// ForceReleaseLock 原子清理孤儿锁(仅在 PTTL == -1 时删除,防止 TOCTOU 竞态误删合法锁) +func (c *userMsgQueueCache) ForceReleaseLock(ctx context.Context, accountID int64) error { + key := umqLockKey(accountID) + _, err := forceReleaseLockScript.Run(ctx, c.rdb, []string{key}).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return fmt.Errorf("umq force release lock: %w", err) + } + return nil +} + +// ScanLockKeys 扫描所有锁 key,仅返回 PTTL == -1(无过期时间)的孤儿锁 accountID 列表 +// 正常的锁都有 PX 过期时间,PTTL == -1 表示异常状态(如 Redis 故障恢复后丢失 TTL) +func (c *userMsgQueueCache) ScanLockKeys(ctx context.Context, maxCount int) ([]int64, error) { + var accountIDs []int64 + var cursor uint64 + pattern := umqScanPattern() + + for { + keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, 100).Result() + if err != nil { + return nil, fmt.Errorf("umq scan lock keys: %w", err) + } + for _, key := range keys { + // 检查 PTTL:只清理 PTTL == -1(无过期时间)的异常锁 + pttl, err := c.rdb.PTTL(ctx, key).Result() + if err != nil { + continue + } + // PTTL 返回值:-2 = key 不存在,-1 = 无过期时间,>0 = 剩余毫秒 + // go-redis 对哨兵值 -1/-2 不乘精度系数,直接返回 time.Duration(-1)/-2 + // 只删除 -1(无过期时间的异常锁),跳过正常持有的锁 + if pttl != time.Duration(-1) { + continue + } + + // 从 key 中提取 accountID: umq:{123}:lock → 提取 {} 内的数字 + openBrace := strings.IndexByte(key, '{') + closeBrace := strings.IndexByte(key, '}') + if openBrace < 0 || closeBrace <= openBrace+1 { + continue + } + idStr := key[openBrace+1 : closeBrace] + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + continue + } + accountIDs = append(accountIDs, id) + if len(accountIDs) >= maxCount { + return accountIDs, nil + } + } + cursor = nextCursor + if cursor == 0 { + break + } + } + return accountIDs, nil +} + +// GetCurrentTimeMs 通过 Redis TIME 命令获取当前服务器时间(毫秒),确保与锁记录的时间源一致 +func (c *userMsgQueueCache) GetCurrentTimeMs(ctx context.Context) (int64, error) { + t, err := c.rdb.Time(ctx).Result() + if err != nil { + return 0, fmt.Errorf("umq get redis time: %w", err) + } + return t.UnixMilli(), nil +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 654bd16b..05b68968 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -10,6 +10,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/apikey" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -60,6 +61,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). Save(ctx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -142,6 +144,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). + SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes). Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) @@ -191,6 +195,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. dbuser.EmailContainsFold(filters.Search), dbuser.UsernameContainsFold(filters.Search), dbuser.NotesContainsFold(filters.Search), + dbuser.HasAPIKeysWith(apikey.KeyContainsFold(filters.Search)), ), ) } @@ -361,10 +366,79 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount return nil } +// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。 +func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) { + if deltaBytes <= 0 { + user, err := r.GetByID(ctx, userID) + if err != nil { + return 0, err + } + return user.SoraStorageUsedBytes, nil + } + var newUsed int64 + err := scanSingleRow(ctx, r.sql, ` + UPDATE users + SET sora_storage_used_bytes = sora_storage_used_bytes + $2 + WHERE id = $1 + AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3) + RETURNING sora_storage_used_bytes + `, []any{userID, deltaBytes, effectiveQuota}, &newUsed) + if err == nil { + return newUsed, nil + } + if errors.Is(err, sql.ErrNoRows) { + // 区分用户不存在和配额冲突 + exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx) + if existsErr != nil { + return 0, existsErr + } + if !exists { + return 0, service.ErrUserNotFound + } + return 0, service.ErrSoraStorageQuotaExceeded + } + return 0, err +} + +// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。 +func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) { + if deltaBytes <= 0 { + user, err := r.GetByID(ctx, userID) + if err != nil { + return 0, err + } + return user.SoraStorageUsedBytes, nil + } + var newUsed int64 + err := scanSingleRow(ctx, r.sql, ` + UPDATE users + SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0) + WHERE id = $1 + RETURNING sora_storage_used_bytes + `, []any{userID, deltaBytes}, &newUsed) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, service.ErrUserNotFound + } + return 0, err + } + return newUsed, nil +} + func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) } +func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + client := clientFromContext(ctx, r.client) + return client.UserAllowedGroup.Create(). + SetUserID(userID). + SetGroupID(groupID). + OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID). + DoNothing(). + Exec(ctx) +} + func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { // 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。 affected, err := r.client.UserAllowedGroup.Delete(). diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 3aed9d9c..2e35e0a0 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -28,13 +28,13 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc // ProvideGitHubReleaseClient 创建 GitHub Release 客户端 // 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient { - return NewGitHubReleaseClient(cfg.Update.ProxyURL) + return NewGitHubReleaseClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError) } // ProvidePricingRemoteClient 创建定价数据远程客户端 // 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据 func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient { - return NewPricingRemoteClient(cfg.Update.ProxyURL) + return NewPricingRemoteClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError) } // ProvideSessionLimitCache 创建会话限制缓存 @@ -53,12 +53,14 @@ var ProviderSet = wire.NewSet( NewAPIKeyRepository, NewGroupRepository, NewAccountRepository, + NewSoraAccountRepository, // Sora 账号扩展表仓储 NewProxyRepository, NewRedeemCodeRepository, NewPromoCodeRepository, NewAnnouncementRepository, NewAnnouncementReadRepository, NewUsageLogRepository, + NewIdempotencyRepository, NewUsageCleanupRepository, NewDashboardAggregationRepository, NewSettingRepository, @@ -77,6 +79,8 @@ var ProviderSet = wire.NewSet( NewTimeoutCounterCache, ProvideConcurrencyCache, ProvideSessionLimitCache, + NewRPMCache, + NewUserMsgQueueCache, NewDashboardCache, NewEmailCache, NewIdentityCache, @@ -104,6 +108,7 @@ var ProviderSet = wire.NewSet( NewOpenAIOAuthClient, NewGeminiOAuthClient, NewGeminiCliCodeAssistClient, + NewGeminiDriveClient, ProvideEnt, ProvideSQLDB, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index efef0452..f15a2074 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -83,6 +83,7 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "last_used_at": null, "quota": 0, "quota_used": 0, "expires_at": null, @@ -122,6 +123,7 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "last_used_at": null, "quota": 0, "quota_used": 0, "expires_at": null, @@ -184,7 +186,12 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, - "claude_code_only": false, + "sora_image_price_360": null, + "sora_image_price_540": null, + "sora_storage_quota_bytes": 0, + "sora_video_price_per_request": null, + "sora_video_price_per_request_hd": null, + "claude_code_only": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, "created_at": "2025-01-02T03:04:05Z", @@ -378,10 +385,12 @@ func TestAPIContracts(t *testing.T) { "user_id": 1, "api_key_id": 100, "account_id": 200, - "request_id": "req_123", - "model": "claude-3", - "group_id": null, - "subscription_id": null, + "request_id": "req_123", + "model": "claude-3", + "request_type": "stream", + "openai_ws_mode": false, + "group_id": null, + "subscription_id": null, "input_tokens": 10, "output_tokens": 20, "cache_creation_tokens": 1, @@ -401,6 +410,8 @@ func TestAPIContracts(t *testing.T) { "first_token_ms": 50, "image_count": 0, "image_size": null, + "media_type": null, + "cache_ttl_overridden": false, "created_at": "2025-01-02T03:04:05Z", "user_agent": null } @@ -488,18 +499,22 @@ func TestAPIContracts(t *testing.T) { "doc_url": "https://docs.example.com", "default_concurrency": 5, "default_balance": 1.25, + "default_subscriptions": [], "enable_model_fallback": false, "fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_antigravity": "gemini-2.5-pro", "fallback_model_gemini": "gemini-2.5-pro", - "fallback_model_openai": "gpt-4o", - "enable_identity_patch": true, - "identity_patch_prompt": "", - "invitation_code_enabled": false, - "home_content": "", + "fallback_model_openai": "gpt-4o", + "enable_identity_patch": true, + "identity_patch_prompt": "", + "sora_client_enabled": false, + "invitation_code_enabled": false, + "home_content": "", "hide_ccs_import_button": false, "purchase_subscription_enabled": false, - "purchase_subscription_url": "" + "purchase_subscription_url": "", + "min_claude_code_version": "", + "custom_menu_items": [] } }`, }, @@ -592,13 +607,13 @@ func newContractDeps(t *testing.T) *contractDeps { RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo, nil) + userService := service.NewUserService(userRepo, nil, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) - subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil) + subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil) @@ -607,12 +622,12 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) - adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) - adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil) + adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ @@ -767,6 +782,10 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID return 0, errors.New("not implemented") } +func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + return errors.New("not implemented") +} + func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { return errors.New("not implemented") } @@ -896,6 +915,10 @@ func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int return nil, errors.New("not implemented") } +func (stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error { + return nil +} + type stubAccountRepo struct { bulkUpdateIDs []int64 } @@ -920,6 +943,10 @@ func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st return nil, errors.New("not implemented") } +func (s *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return errors.New("not implemented") } @@ -932,7 +959,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination return nil, nil, errors.New("not implemented") } -func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { +func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } @@ -1004,10 +1031,6 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt return errors.New("not implemented") } -func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { - return errors.New("not implemented") -} - func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { return errors.New("not implemented") } @@ -1049,6 +1072,10 @@ func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s return int64(len(ids)), nil } +func (s *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { + return nil, errors.New("not implemented") +} + type stubProxyRepo struct{} func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error { @@ -1457,6 +1484,20 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + key, ok := r.byID[id] + if !ok { + return service.ErrAPIKeyNotFound + } + ts := usedAt + key.LastUsedAt = &ts + key.UpdatedAt = usedAt + clone := *key + r.byID[id] = &clone + r.byKey[clone.Key] = &clone + return nil +} + type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } @@ -1525,11 +1566,15 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { +func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { +func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { return nil, errors.New("not implemented") } @@ -1602,11 +1647,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { +func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { +func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index d2d8ed40..a8034e98 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -51,6 +51,9 @@ func ProvideRouter( if err := r.SetTrustedProxies(nil); err != nil { log.Printf("Failed to disable trusted proxies: %v", err) } + if cfg.Server.Mode == "release" { + log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled") + } } return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) diff --git a/backend/internal/server/middleware/admin_auth.go b/backend/internal/server/middleware/admin_auth.go index 8f30107c..6f294ff0 100644 --- a/backend/internal/server/middleware/admin_auth.go +++ b/backend/internal/server/middleware/admin_auth.go @@ -58,8 +58,13 @@ func adminAuth( authHeader := c.GetHeader("Authorization") if authHeader != "" { parts := strings.SplitN(authHeader, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" { - if !validateJWTForAdmin(c, parts[1], authService, userService) { + if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") { + token := strings.TrimSpace(parts[1]) + if token == "" { + AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required") + return + } + if !validateJWTForAdmin(c, token, authService, userService) { return } c.Next() @@ -176,6 +181,12 @@ func validateJWTForAdmin( return false } + // 校验 TokenVersion,确保管理员改密后旧 token 失效 + if claims.TokenVersion != user.TokenVersion { + AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)") + return false + } + // 检查管理员权限 if !user.IsAdmin() { AbortWithError(c, 403, "FORBIDDEN", "Admin access required") diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go new file mode 100644 index 00000000..033a5b77 --- /dev/null +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -0,0 +1,198 @@ +//go:build unit + +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} + authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + + admin := &service.User{ + ID: 1, + Email: "admin@example.com", + Role: service.RoleAdmin, + Status: service.StatusActive, + TokenVersion: 2, + Concurrency: 1, + } + + userRepo := &stubUserRepo{ + getByID: func(ctx context.Context, id int64) (*service.User, error) { + if id != admin.ID { + return nil, service.ErrUserNotFound + } + clone := *admin + return &clone, nil + }, + } + userService := service.NewUserService(userRepo, nil, nil) + + router := gin.New() + router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + t.Run("token_version_mismatch_rejected", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion - 1, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Contains(t, w.Body.String(), "TOKEN_REVOKED") + }) + + t.Run("token_version_match_allows", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion - 1, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Contains(t, w.Body.String(), "TOKEN_REVOKED") + }) + + t.Run("websocket_token_version_match_allows", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) +} + +type stubUserRepo struct { + getByID func(ctx context.Context, id int64) (*service.User, error) +} + +func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error { + panic("unexpected Create call") +} + +func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) { + if s.getByID == nil { + panic("GetByID not stubbed") + } + return s.getByID(ctx, id) +} + +func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) { + panic("unexpected GetByEmail call") +} + +func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error { + panic("unexpected Update call") +} + +func (s *stubUserRepo) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error { + panic("unexpected UpdateBalance call") +} + +func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error { + panic("unexpected DeductBalance call") +} + +func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error { + panic("unexpected UpdateConcurrency call") +} + +func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + panic("unexpected ExistsByEmail call") +} + +func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected RemoveGroupFromAllowedGroups call") +} + +func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + panic("unexpected AddGroupToAllowedGroups call") +} + +func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { + panic("unexpected UpdateTotpSecret call") +} + +func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error { + panic("unexpected EnableTotp call") +} + +func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error { + panic("unexpected DisableTotp call") +} diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 2f739357..19f97239 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -3,7 +3,6 @@ package middleware import ( "context" "errors" - "log" "strings" "github.com/Wei-Shaw/sub2api/internal/config" @@ -36,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti if authHeader != "" { // 验证Bearer scheme parts := strings.SplitN(authHeader, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" { - apiKeyString = parts[1] + if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") { + apiKeyString = strings.TrimSpace(parts[1]) } } @@ -97,8 +96,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti // 检查 IP 限制(白名单/黑名单) // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { - clientIP := ip.GetClientIP(c) - allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist) + clientIP := ip.GetTrustedClientIP(c) + allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist) if !allowed { AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") return @@ -126,6 +125,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() return } @@ -134,7 +134,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() if isSubscriptionType && subscriptionService != nil { - // 订阅模式:验证订阅 + // 订阅模式:获取订阅(L1 缓存 + singleflight) subscription, err := subscriptionService.GetActiveSubscription( c.Request.Context(), apiKey.User.ID, @@ -145,30 +145,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } - // 验证订阅状态(是否过期、暂停等) - if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { - AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error()) - return - } - - // 激活滑动窗口(首次使用时) - if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil { - log.Printf("Failed to activate subscription windows: %v", err) - } - - // 检查并重置过期窗口 - if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil { - log.Printf("Failed to reset subscription windows: %v", err) - } - - // 预检查用量限制(使用0作为额外费用进行预检查) - if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { - AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error()) + // 合并验证 + 限额检查(纯内存操作) + needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) + if err != nil { + code := "SUBSCRIPTION_INVALID" + status := 403 + if errors.Is(err, service.ErrDailyLimitExceeded) || + errors.Is(err, service.ErrWeeklyLimitExceeded) || + errors.Is(err, service.ErrMonthlyLimitExceeded) { + code = "USAGE_LIMIT_EXCEEDED" + status = 429 + } + AbortWithError(c, status, code, err.Error()) return } // 将订阅信息存入上下文 c.Set(string(ContextKeySubscription), subscription) + + // 窗口维护异步化(不阻塞请求) + // 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race + if needsMaintenance { + maintenanceCopy := *subscription + subscriptionService.DoWindowMaintenance(&maintenanceCopy) + } } else { // 余额模式:检查用户余额 if apiKey.User.Balance <= 0 { @@ -185,6 +185,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() } diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 38fbe38b..84d93edc 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -64,6 +64,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() return } @@ -79,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs abortWithGoogleError(c, 403, "No active subscription found for this group") return } - if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { - abortWithGoogleError(c, 403, err.Error()) - return - } - _ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription) - _ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription) - if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { - abortWithGoogleError(c, 429, err.Error()) + + needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) + if err != nil { + status := 403 + if errors.Is(err, service.ErrDailyLimitExceeded) || + errors.Is(err, service.ErrWeeklyLimitExceeded) || + errors.Is(err, service.ErrMonthlyLimitExceeded) { + status = 429 + } + abortWithGoogleError(c, status, err.Error()) return } + c.Set(string(ContextKeySubscription), subscription) + + if needsMaintenance { + maintenanceCopy := *subscription + subscriptionService.DoWindowMaintenance(&maintenanceCopy) + } } else { if apiKey.User.Balance <= 0 { abortWithGoogleError(c, 403, "Insufficient account balance") @@ -104,6 +113,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() } } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 38b93cb2..2124c86c 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" @@ -18,7 +19,17 @@ import ( ) type fakeAPIKeyRepo struct { - getByKey func(ctx context.Context, key string) (*service.APIKey, error) + getByKey func(ctx context.Context, key string) (*service.APIKey, error) + updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error +} + +type fakeGoogleSubscriptionRepo struct { + getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) + updateStatus func(ctx context.Context, subscriptionID int64, status string) error + activateWindow func(ctx context.Context, id int64, start time.Time) error + resetDaily func(ctx context.Context, id int64, start time.Time) error + resetWeekly func(ctx context.Context, id int64, start time.Time) error + resetMonthly func(ctx context.Context, id int64, start time.Time) error } func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { @@ -78,6 +89,91 @@ func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([ func (f fakeAPIKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { return 0, errors.New("not implemented") } +func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + if f.updateLastUsed != nil { + return f.updateLastUsed(ctx, id, usedAt) + } + return nil +} + +func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if f.getActive != nil { + return f.getActive(ctx, userID, groupID) + } + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { + return false, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { + if f.updateStatus != nil { + return f.updateStatus(ctx, subscriptionID, status) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error { + if f.activateWindow != nil { + return f.activateWindow(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetDaily != nil { + return f.resetDaily(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetWeekly != nil { + return f.resetWeekly(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetMonthly != nil { + return f.resetMonthly(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { + return 0, errors.New("not implemented") +} type googleErrorResponse struct { Error struct { @@ -356,3 +452,226 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { require.Equal(t, "Insufficient account balance", resp.Error.Message) require.Equal(t, "PERMISSION_DENIED", resp.Error.Status) } + +func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedOnSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 11, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 201, + UserID: user.ID, + Key: "google-touch-ok", + Status: service.StatusActive, + User: user, + } + + var touchedID int64 + var touchedAt time.Time + r := gin.New() + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchedID = id + touchedAt = usedAt + return nil + }, + }) + cfg := &config.Config{RunMode: config.RunModeSimple} + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("x-goog-api-key", apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, apiKey.ID, touchedID) + require.False(t, touchedAt.IsZero()) +} + +func TestApiKeyAuthWithSubscriptionGoogle_TouchFailureDoesNotBlock(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 12, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 202, + UserID: user.ID, + Key: "google-touch-fail", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + r := gin.New() + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return errors.New("write failed") + }, + }) + cfg := &config.Config{RunMode: config.RunModeSimple} + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("x-goog-api-key", apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 1, touchCalls) +} + +func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 13, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 203, + UserID: user.ID, + Key: "google-touch-standard", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + r := gin.New() + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return nil + }, + }) + cfg := &config.Config{RunMode: config.RunModeStandard} + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("Authorization", "Bearer "+apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 1, touchCalls) +} + +func TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429(t *testing.T) { + gin.SetMode(gin.TestMode) + + limit := 1.0 + group := &service.Group{ + ID: 77, + Name: "gemini-sub", + Status: service.StatusActive, + Platform: service.PlatformGemini, + Hydrated: true, + SubscriptionType: service.SubscriptionTypeSubscription, + DailyLimitUSD: &limit, + } + user := &service.User{ + ID: 999, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 501, + UserID: user.ID, + Key: "google-sub-limit", + Status: service.StatusActive, + User: user, + Group: group, + } + apiKey.GroupID = &group.ID + + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + }) + + now := time.Now() + sub := &service.UserSubscription{ + ID: 601, + UserID: user.ID, + GroupID: group.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: now.Add(24 * time.Hour), + DailyWindowStart: &now, + DailyUsageUSD: 10, + } + subscriptionService := service.NewSubscriptionService(nil, fakeGoogleSubscriptionRepo{ + getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if userID != user.ID || groupID != group.ID { + return nil, service.ErrSubscriptionNotFound + } + clone := *sub + return &clone, nil + }, + updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil }, + activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + }, nil, nil, &config.Config{RunMode: config.RunModeStandard}) + + r := gin.New() + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, &config.Config{RunMode: config.RunModeStandard})) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("x-goog-api-key", apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusTooManyRequests, rec.Code) + var resp googleErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, http.StatusTooManyRequests, resp.Error.Code) + require.Equal(t, "RESOURCE_EXHAUSTED", resp.Error.Status) + require.Contains(t, resp.Error.Message, "daily usage limit exceeded") +} diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 9d514818..0d331761 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -57,10 +57,41 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { }, } - t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { - cfg := &config.Config{RunMode: config.RunModeSimple} + t.Run("standard_mode_needs_maintenance_does_not_block_request", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeStandard} + cfg.SubscriptionMaintenance.WorkerCount = 1 + cfg.SubscriptionMaintenance.QueueSize = 1 + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) - subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) + + past := time.Now().Add(-48 * time.Hour) + sub := &service.UserSubscription{ + ID: 55, + UserID: user.ID, + GroupID: group.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + DailyWindowStart: &past, + DailyUsageUSD: 0, + } + maintenanceCalled := make(chan struct{}, 1) + subscriptionRepo := &stubUserSubscriptionRepo{ + getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + clone := *sub + return &clone, nil + }, + updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil }, + activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetDaily: func(ctx context.Context, id int64, start time.Time) error { + maintenanceCalled <- struct{}{} + return nil + }, + resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + } + subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg) + t.Cleanup(subscriptionService.Stop) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() @@ -68,6 +99,40 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { req.Header.Set("x-api-key", apiKey.Key) router.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + select { + case <-maintenanceCalled: + // ok + case <-time.After(time.Second): + t.Fatalf("expected maintenance to be scheduled") + } + }) + + t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("simple_mode_accepts_lowercase_bearer", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Authorization", "bearer "+apiKey.Key) + router.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) }) @@ -99,7 +164,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, } - subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil) + subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() @@ -235,6 +300,198 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + IPWhitelist: []string{"1.2.3.4"}, + } + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := gin.New() + require.NoError(t, router.SetTrustedProxies(nil)) + router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.RemoteAddr = "9.9.9.9:12345" + req.Header.Set("x-api-key", apiKey.Key) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + req.Header.Set("X-Real-IP", "1.2.3.4") + req.Header.Set("CF-Connecting-IP", "1.2.3.4") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusForbidden, w.Code) + require.Contains(t, w.Body.String(), "ACCESS_DENIED") +} + +func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "touch-ok", + Status: service.StatusActive, + User: user, + } + + var touchedID int64 + var touchedAt time.Time + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchedID = id + touchedAt = usedAt + return nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, apiKey.ID, touchedID) + require.False(t, touchedAt.IsZero(), "expected touch timestamp") +} + +func TestAPIKeyAuthTouchLastUsedFailureDoesNotBlock(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 8, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 101, + UserID: user.ID, + Key: "touch-fail", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return errors.New("db unavailable") + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "touch failure should not block request") + require.Equal(t, 1, touchCalls) +} + +func TestAPIKeyAuthTouchesLastUsedInStandardMode(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 9, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 102, + UserID: user.ID, + Key: "touch-standard", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeStandard} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, 1, touchCalls) +} + func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) @@ -245,7 +502,8 @@ func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService } type stubApiKeyRepo struct { - getByKey func(ctx context.Context, key string) (*service.APIKey, error) + getByKey func(ctx context.Context, key string) (*service.APIKey, error) + updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error } func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error { @@ -323,6 +581,13 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + if r.updateLastUsed != nil { + return r.updateLastUsed(ctx, id, usedAt) + } + return nil +} + type stubUserSubscriptionRepo struct { getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status string) error diff --git a/backend/internal/server/middleware/client_request_id.go b/backend/internal/server/middleware/client_request_id.go index d22b6cc5..6838d6af 100644 --- a/backend/internal/server/middleware/client_request_id.go +++ b/backend/internal/server/middleware/client_request_id.go @@ -2,10 +2,13 @@ package middleware import ( "context" + "strings" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" "github.com/google/uuid" + "go.uber.org/zap" ) // ClientRequestID ensures every request has a unique client_request_id in request.Context(). @@ -24,7 +27,10 @@ func ClientRequestID() gin.HandlerFunc { } id := uuid.New().String() - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)) + ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id) + requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id))) + ctx = logger.IntoContext(ctx, requestLogger) + c.Request = c.Request.WithContext(ctx) c.Next() } } diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go index 7d82f183..03d5d025 100644 --- a/backend/internal/server/middleware/cors.go +++ b/backend/internal/server/middleware/cors.go @@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { } allowedSet[origin] = struct{}{} } + allowHeaders := []string{ + "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", + "accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key", + } + // OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。 + openAIProperties := []string{ + "lang", "package-version", "os", "arch", "retry-count", "runtime", + "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout", + } + for _, prop := range openAIProperties { + allowHeaders = append(allowHeaders, "x-stainless-"+prop) + } + allowHeadersValue := strings.Join(allowHeaders, ", ") return func(c *gin.Context) { origin := strings.TrimSpace(c.GetHeader("Origin")) @@ -68,11 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { if allowCredentials { c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") } + c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeadersValue) + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") + c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag") + c.Writer.Header().Set("Access-Control-Max-Age", "86400") } - - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") - c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") - // 处理预检请求 if c.Request.Method == http.MethodOptions { if originAllowed { diff --git a/backend/internal/server/middleware/cors_test.go b/backend/internal/server/middleware/cors_test.go new file mode 100644 index 00000000..6d0bea36 --- /dev/null +++ b/backend/internal/server/middleware/cors_test.go @@ -0,0 +1,308 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func init() { + // cors_test 与 security_headers_test 在同一个包,但 init 是幂等的 + gin.SetMode(gin.TestMode) +} + +// --- Task 8.2: 验证 CORS 条件化头部 --- + +func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + tests := []struct { + name string + method string + origin string + }{ + { + name: "preflight_disallowed_origin", + method: http.MethodOptions, + origin: "https://evil.example.com", + }, + { + name: "get_disallowed_origin", + method: http.MethodGet, + origin: "https://evil.example.com", + }, + { + name: "post_disallowed_origin", + method: http.MethodPost, + origin: "https://attacker.example.com", + }, + { + name: "preflight_no_origin", + method: http.MethodOptions, + origin: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(tt.method, "/", nil) + if tt.origin != "" { + c.Request.Header.Set("Origin", tt.origin) + } + + middleware(c) + + // 不应设置 Allow-Headers、Allow-Methods 和 Max-Age + assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"), + "不允许的 origin 不应收到 Allow-Headers") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"), + "不允许的 origin 不应收到 Allow-Methods") + assert.Empty(t, w.Header().Get("Access-Control-Max-Age"), + "不允许的 origin 不应收到 Max-Age") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"), + "不允许的 origin 不应收到 Allow-Origin") + }) + } +} + +func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + tests := []struct { + name string + method string + }{ + {name: "preflight_OPTIONS", method: http.MethodOptions}, + {name: "normal_GET", method: http.MethodGet}, + {name: "normal_POST", method: http.MethodPost}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(tt.method, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + // 应设置 Allow-Headers、Allow-Methods 和 Max-Age + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"), + "允许的 origin 应收到 Allow-Headers") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"), + "允许的 origin 应收到 Allow-Methods") + assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"), + "允许的 origin 应收到 Max-Age=86400") + assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"), + "允许的 origin 应收到 Allow-Origin") + }) + } +} + +func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodOptions, "/", nil) + c.Request.Header.Set("Origin", "https://evil.example.com") + + middleware(c) + + assert.Equal(t, http.StatusForbidden, w.Code, + "不允许的 origin 的 preflight 请求应返回 403") +} + +func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodOptions, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + assert.Equal(t, http.StatusNoContent, w.Code, + "允许的 origin 的 preflight 请求应返回 204") +} + +func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://any-origin.example.com") + + middleware(c) + + assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"), + "通配符配置应返回 *") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"), + "通配符 origin 应设置 Allow-Headers") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"), + "通配符 origin 应设置 Allow-Methods") +} + +func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: true, + } + middleware := CORS(cfg) + + t.Run("allowed_origin_gets_credentials", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"), + "允许的 origin 且开启 credentials 应设置 Allow-Credentials") + }) + + t.Run("disallowed_origin_no_credentials", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://evil.example.com") + + middleware(c) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"), + "不允许的 origin 不应收到 Allow-Credentials") + }) +} + +func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: true, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://any.example.com") + + middleware(c) + + // 通配符 + credentials 不兼容,credentials 应被禁用 + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"), + "通配符 origin 应禁用 Allow-Credentials") +} + +func TestCORS_MultipleAllowedOrigins(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{ + "https://app1.example.com", + "https://app2.example.com", + }, + AllowCredentials: false, + } + middleware := CORS(cfg) + + t.Run("first_origin_allowed", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app1.example.com") + + middleware(c) + + assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) + + t.Run("second_origin_allowed", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app2.example.com") + + middleware(c) + + assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) + + t.Run("unlisted_origin_rejected", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app3.example.com") + + middleware(c) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) +} + +func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + assert.Contains(t, w.Header().Values("Vary"), "Origin", + "非通配符允许的 origin 应设置 Vary: Origin") +} + +func TestNormalizeOrigins(t *testing.T) { + tests := []struct { + name string + input []string + expect []string + }{ + {name: "nil_input", input: nil, expect: nil}, + {name: "empty_input", input: []string{}, expect: nil}, + {name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}}, + {name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeOrigins(tt.input) + assert.Equal(t, tt.expect, result) + }) + } +} diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go index 9a89aab7..4aceb355 100644 --- a/backend/internal/server/middleware/jwt_auth.go +++ b/backend/internal/server/middleware/jwt_auth.go @@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService) // 验证Bearer scheme parts := strings.SplitN(authHeader, " ", 2) - if len(parts) != 2 || parts[0] != "Bearer" { + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'") return } - tokenString := parts[1] + tokenString := strings.TrimSpace(parts[1]) if tokenString == "" { AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty") return diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go new file mode 100644 index 00000000..f8839cfe --- /dev/null +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -0,0 +1,256 @@ +//go:build unit + +package middleware + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// stubJWTUserRepo 实现 UserRepository 的最小子集,仅支持 GetByID。 +type stubJWTUserRepo struct { + service.UserRepository + users map[int64]*service.User +} + +func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, error) { + u, ok := r.users[id] + if !ok { + return nil, errors.New("user not found") + } + return u, nil +} + +// newJWTTestEnv 创建 JWT 认证中间件测试环境。 +// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。 +func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!" + cfg.JWT.AccessTokenExpireMinutes = 60 + + userRepo := &stubJWTUserRepo{users: users} + authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + userSvc := service.NewUserService(userRepo, nil, nil) + mw := NewJWTAuthMiddleware(authSvc, userSvc) + + r := gin.New() + r.Use(gin.HandlerFunc(mw)) + r.GET("/protected", func(c *gin.Context) { + subject, _ := GetAuthSubjectFromContext(c) + role, _ := GetUserRoleFromContext(c) + c.JSON(http.StatusOK, gin.H{ + "user_id": subject.UserID, + "role": role, + }) + }) + return r, authSvc +} + +func TestJWTAuth_ValidToken(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + Concurrency: 5, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + + var body map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, float64(1), body["user_id"]) + require.Equal(t, "user", body["role"]) +} + +func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + Concurrency: 5, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) +} + +func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "UNAUTHORIZED", body.Code) +} + +func TestJWTAuth_InvalidHeaderFormat(t *testing.T) { + tests := []struct { + name string + header string + }{ + {"无Bearer前缀", "Token abc123"}, + {"缺少空格分隔", "Bearerabc123"}, + {"仅有单词", "abc123"}, + } + router, _ := newJWTTestEnv(nil) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", tt.header) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "INVALID_AUTH_HEADER", body.Code) + }) + } +} + +func TestJWTAuth_EmptyToken(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer ") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "EMPTY_TOKEN", body.Code) +} + +func TestJWTAuth_TamperedToken(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.invalid_signature") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "INVALID_TOKEN", body.Code) +} + +func TestJWTAuth_UserNotFound(t *testing.T) { + // 使用 user ID=1 的 token,但 repo 中没有该用户 + fakeUser := &service.User{ + ID: 999, + Email: "ghost@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 1, + } + // 创建环境时不注入此用户,这样 GetByID 会失败 + router, authSvc := newJWTTestEnv(map[int64]*service.User{}) + + token, err := authSvc.GenerateToken(fakeUser) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "USER_NOT_FOUND", body.Code) +} + +func TestJWTAuth_UserInactive(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "disabled@example.com", + Role: "user", + Status: service.StatusDisabled, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "USER_INACTIVE", body.Code) +} + +func TestJWTAuth_TokenVersionMismatch(t *testing.T) { + // Token 生成时 TokenVersion=1,但数据库中用户已更新为 TokenVersion=2(密码修改) + userForToken := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 1, + } + userInDB := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 2, // 密码修改后版本递增 + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: userInDB}) + + token, err := authSvc.GenerateToken(userForToken) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "TOKEN_REVOKED", body.Code) +} diff --git a/backend/internal/server/middleware/logger.go b/backend/internal/server/middleware/logger.go index 842efda9..b14a3a21 100644 --- a/backend/internal/server/middleware/logger.go +++ b/backend/internal/server/middleware/logger.go @@ -1,10 +1,12 @@ package middleware import ( - "log" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) // Logger 请求日志中间件 @@ -13,44 +15,52 @@ func Logger() gin.HandlerFunc { // 开始时间 startTime := time.Now() - // 处理请求 - c.Next() - - // 结束时间 - endTime := time.Now() - - // 执行时间 - latency := endTime.Sub(startTime) - - // 请求方法 - method := c.Request.Method - // 请求路径 path := c.Request.URL.Path - // 状态码 + // 处理请求 + c.Next() + + // 跳过健康检查等高频探针路径的日志 + if path == "/health" || path == "/setup/status" { + return + } + + endTime := time.Now() + latency := endTime.Sub(startTime) + + method := c.Request.Method statusCode := c.Writer.Status() - - // 客户端IP clientIP := c.ClientIP() - - // 协议版本 protocol := c.Request.Proto + accountID, hasAccountID := c.Request.Context().Value(ctxkey.AccountID).(int64) + platform, _ := c.Request.Context().Value(ctxkey.Platform).(string) + model, _ := c.Request.Context().Value(ctxkey.Model).(string) - // 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径 - log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s", - endTime.Format("2006/01/02 - 15:04:05"), - statusCode, - latency, - clientIP, - protocol, - method, - path, - ) + fields := []zap.Field{ + zap.String("component", "http.access"), + zap.Int("status_code", statusCode), + zap.Int64("latency_ms", latency.Milliseconds()), + zap.String("client_ip", clientIP), + zap.String("protocol", protocol), + zap.String("method", method), + zap.String("path", path), + } + if hasAccountID && accountID > 0 { + fields = append(fields, zap.Int64("account_id", accountID)) + } + if platform != "" { + fields = append(fields, zap.String("platform", platform)) + } + if model != "" { + fields = append(fields, zap.String("model", model)) + } + + l := logger.FromContext(c.Request.Context()).With(fields...) + l.Info("http request completed", zap.Time("completed_at", endTime)) - // 如果有错误,额外记录错误信息 if len(c.Errors) > 0 { - log.Printf("[GIN] Errors: %v", c.Errors.String()) + l.Warn("http request contains gin errors", zap.String("errors", c.Errors.String())) } } } diff --git a/backend/internal/server/middleware/misc_coverage_test.go b/backend/internal/server/middleware/misc_coverage_test.go new file mode 100644 index 00000000..c0adfc4d --- /dev/null +++ b/backend/internal/server/middleware/misc_coverage_test.go @@ -0,0 +1,126 @@ +//go:build unit + +package middleware + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestClientRequestID_GeneratesWhenMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ClientRequestID()) + r.GET("/t", func(c *gin.Context) { + v := c.Request.Context().Value(ctxkey.ClientRequestID) + require.NotNil(t, v) + id, ok := v.(string) + require.True(t, ok) + require.NotEmpty(t, id) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestClientRequestID_PreservesExisting(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ClientRequestID()) + r.GET("/t", func(c *gin.Context) { + id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string) + require.True(t, ok) + require.Equal(t, "keep", id) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep")) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestRequestBodyLimit_LimitsBody(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(RequestBodyLimit(4)) + r.POST("/t", func(c *gin.Context) { + _, err := io.ReadAll(c.Request.Body) + require.Error(t, err) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345")) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestForcePlatform_SetsContextAndGinValue(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ForcePlatform("anthropic")) + r.GET("/t", func(c *gin.Context) { + require.True(t, HasForcePlatform(c)) + v, ok := GetForcePlatformFromContext(c) + require.True(t, ok) + require.Equal(t, "anthropic", v) + + ctxV := c.Request.Context().Value(ctxkey.ForcePlatform) + require.Equal(t, "anthropic", ctxV) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestAuthSubjectHelpers_RoundTrip(t *testing.T) { + c := &gin.Context{} + c.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2}) + c.Set(string(ContextKeyUserRole), "admin") + + sub, ok := GetAuthSubjectFromContext(c) + require.True(t, ok) + require.Equal(t, int64(1), sub.UserID) + require.Equal(t, 2, sub.Concurrency) + + role, ok := GetUserRoleFromContext(c) + require.True(t, ok) + require.Equal(t, "admin", role) +} + +func TestAPIKeyAndSubscriptionFromContext(t *testing.T) { + c := &gin.Context{} + + key := &service.APIKey{ID: 1} + c.Set(string(ContextKeyAPIKey), key) + gotKey, ok := GetAPIKeyFromContext(c) + require.True(t, ok) + require.Equal(t, int64(1), gotKey.ID) + + sub := &service.UserSubscription{ID: 2} + c.Set(string(ContextKeySubscription), sub) + gotSub, ok := GetSubscriptionFromContext(c) + require.True(t, ok) + require.Equal(t, int64(2), gotSub.ID) +} diff --git a/backend/internal/server/middleware/recovery_test.go b/backend/internal/server/middleware/recovery_test.go index 439f44cb..33e71d51 100644 --- a/backend/internal/server/middleware/recovery_test.go +++ b/backend/internal/server/middleware/recovery_test.go @@ -3,6 +3,7 @@ package middleware import ( + "bytes" "encoding/json" "net/http" "net/http/httptest" @@ -14,6 +15,34 @@ import ( "github.com/stretchr/testify/require" ) +func TestRecovery_PanicLogContainsInfo(t *testing.T) { + gin.SetMode(gin.TestMode) + + // 临时替换 DefaultErrorWriter 以捕获日志输出 + var buf bytes.Buffer + originalWriter := gin.DefaultErrorWriter + gin.DefaultErrorWriter = &buf + t.Cleanup(func() { + gin.DefaultErrorWriter = originalWriter + }) + + r := gin.New() + r.Use(Recovery()) + r.GET("/panic", func(c *gin.Context) { + panic("custom panic message for test") + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/panic", nil) + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusInternalServerError, w.Code) + + logOutput := buf.String() + require.Contains(t, logOutput, "custom panic message for test", "日志应包含 panic 信息") + require.Contains(t, logOutput, "recovery_test.go", "日志应包含堆栈跟踪文件名") +} + func TestRecovery(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/server/middleware/request_access_logger_test.go b/backend/internal/server/middleware/request_access_logger_test.go new file mode 100644 index 00000000..fec3ed22 --- /dev/null +++ b/backend/internal/server/middleware/request_access_logger_test.go @@ -0,0 +1,228 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" +) + +type testLogSink struct { + mu sync.Mutex + events []*logger.LogEvent +} + +func (s *testLogSink) WriteLogEvent(event *logger.LogEvent) { + s.mu.Lock() + defer s.mu.Unlock() + s.events = append(s.events, event) +} + +func (s *testLogSink) list() []*logger.LogEvent { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]*logger.LogEvent, len(s.events)) + copy(out, s.events) + return out +} + +func initMiddlewareTestLogger(t *testing.T) *testLogSink { + return initMiddlewareTestLoggerWithLevel(t, "debug") +} + +func initMiddlewareTestLoggerWithLevel(t *testing.T, level string) *testLogSink { + t.Helper() + level = strings.TrimSpace(level) + if level == "" { + level = "debug" + } + if err := logger.Init(logger.InitOptions{ + Level: level, + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: false, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + sink := &testLogSink{} + logger.SetSink(sink) + t.Cleanup(func() { + logger.SetSink(nil) + }) + return sink +} + +func TestRequestLogger_GenerateAndPropagateRequestID(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(RequestLogger()) + r.GET("/t", func(c *gin.Context) { + reqID, ok := c.Request.Context().Value(ctxkey.RequestID).(string) + if !ok || reqID == "" { + t.Fatalf("request_id missing in context") + } + if got := c.Writer.Header().Get(requestIDHeader); got != reqID { + t.Fatalf("response header request_id mismatch, header=%q ctx=%q", got, reqID) + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d", w.Code) + } + if w.Header().Get(requestIDHeader) == "" { + t.Fatalf("X-Request-ID should be set") + } +} + +func TestRequestLogger_KeepIncomingRequestID(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(RequestLogger()) + r.GET("/t", func(c *gin.Context) { + reqID, _ := c.Request.Context().Value(ctxkey.RequestID).(string) + if reqID != "rid-fixed" { + t.Fatalf("request_id=%q, want rid-fixed", reqID) + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set(requestIDHeader, "rid-fixed") + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d", w.Code) + } + if got := w.Header().Get(requestIDHeader); got != "rid-fixed" { + t.Fatalf("header=%q, want rid-fixed", got) + } +} + +func TestLogger_AccessLogIncludesCoreFields(t *testing.T) { + gin.SetMode(gin.TestMode) + sink := initMiddlewareTestLogger(t) + + r := gin.New() + r.Use(Logger()) + r.Use(func(c *gin.Context) { + ctx := c.Request.Context() + ctx = context.WithValue(ctx, ctxkey.AccountID, int64(101)) + ctx = context.WithValue(ctx, ctxkey.Platform, "openai") + ctx = context.WithValue(ctx, ctxkey.Model, "gpt-5") + c.Request = c.Request.WithContext(ctx) + c.Next() + }) + r.GET("/api/test", func(c *gin.Context) { + c.Status(http.StatusCreated) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("status=%d", w.Code) + } + + events := sink.list() + if len(events) == 0 { + t.Fatalf("expected at least one log event") + } + found := false + for _, event := range events { + if event == nil || event.Message != "http request completed" { + continue + } + found = true + switch v := event.Fields["status_code"].(type) { + case int: + if v != http.StatusCreated { + t.Fatalf("status_code field mismatch: %v", v) + } + case int64: + if v != int64(http.StatusCreated) { + t.Fatalf("status_code field mismatch: %v", v) + } + default: + t.Fatalf("status_code type mismatch: %T", v) + } + switch v := event.Fields["account_id"].(type) { + case int64: + if v != 101 { + t.Fatalf("account_id field mismatch: %v", v) + } + case int: + if v != 101 { + t.Fatalf("account_id field mismatch: %v", v) + } + default: + t.Fatalf("account_id type mismatch: %T", v) + } + if event.Fields["platform"] != "openai" || event.Fields["model"] != "gpt-5" { + t.Fatalf("platform/model mismatch: %+v", event.Fields) + } + } + if !found { + t.Fatalf("access log event not found") + } +} + +func TestLogger_HealthPathSkipped(t *testing.T) { + gin.SetMode(gin.TestMode) + sink := initMiddlewareTestLogger(t) + + r := gin.New() + r.Use(Logger()) + r.GET("/health", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d", w.Code) + } + if len(sink.list()) != 0 { + t.Fatalf("health endpoint should not write access log") + } +} + +func TestLogger_AccessLogDroppedWhenLevelWarn(t *testing.T) { + gin.SetMode(gin.TestMode) + sink := initMiddlewareTestLoggerWithLevel(t, "warn") + + r := gin.New() + r.Use(RequestLogger()) + r.Use(Logger()) + r.GET("/api/test", func(c *gin.Context) { + c.Status(http.StatusCreated) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("status=%d", w.Code) + } + + events := sink.list() + for _, event := range events { + if event != nil && event.Message == "http request completed" { + t.Fatalf("access log should not be indexed when level=warn: %+v", event) + } + } +} diff --git a/backend/internal/server/middleware/request_logger.go b/backend/internal/server/middleware/request_logger.go new file mode 100644 index 00000000..0fb2feca --- /dev/null +++ b/backend/internal/server/middleware/request_logger.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "context" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "go.uber.org/zap" +) + +const requestIDHeader = "X-Request-ID" + +// RequestLogger 在请求入口注入 request-scoped logger。 +func RequestLogger() gin.HandlerFunc { + return func(c *gin.Context) { + if c.Request == nil { + c.Next() + return + } + + requestID := strings.TrimSpace(c.GetHeader(requestIDHeader)) + if requestID == "" { + requestID = uuid.NewString() + } + c.Header(requestIDHeader, requestID) + + ctx := context.WithValue(c.Request.Context(), ctxkey.RequestID, requestID) + clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string) + + requestLogger := logger.With( + zap.String("component", "http"), + zap.String("request_id", requestID), + zap.String("client_request_id", strings.TrimSpace(clientRequestID)), + zap.String("path", c.Request.URL.Path), + zap.String("method", c.Request.Method), + ) + + ctx = logger.IntoContext(ctx, requestLogger) + c.Request = c.Request.WithContext(ctx) + c.Next() + } +} diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 9ce7f449..d9ec951e 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -3,6 +3,8 @@ package middleware import ( "crypto/rand" "encoding/base64" + "fmt" + "log" "strings" "github.com/Wei-Shaw/sub2api/internal/config" @@ -18,11 +20,14 @@ const ( CloudflareInsightsDomain = "https://static.cloudflareinsights.com" ) -// GenerateNonce generates a cryptographically secure random nonce -func GenerateNonce() string { +// GenerateNonce generates a cryptographically secure random nonce. +// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。 +func GenerateNonce() (string, error) { b := make([]byte, 16) - _, _ = rand.Read(b) - return base64.StdEncoding.EncodeToString(b) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate CSP nonce: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil } // GetNonceFromContext retrieves the CSP nonce from gin context @@ -36,7 +41,9 @@ func GetNonceFromContext(c *gin.Context) string { } // SecurityHeaders sets baseline security headers for all responses. -func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { +// getFrameSrcOrigins is an optional function that returns extra origins to inject into frame-src; +// pass nil to disable dynamic frame-src injection. +func SecurityHeaders(cfg config.CSPConfig, getFrameSrcOrigins func() []string) gin.HandlerFunc { policy := strings.TrimSpace(cfg.Policy) if policy == "" { policy = config.DefaultCSPPolicy @@ -46,23 +53,51 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { policy = enhanceCSPPolicy(policy) return func(c *gin.Context) { + finalPolicy := policy + if getFrameSrcOrigins != nil { + for _, origin := range getFrameSrcOrigins() { + if origin != "" { + finalPolicy = addToDirective(finalPolicy, "frame-src", origin) + } + } + } + c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Frame-Options", "DENY") c.Header("Referrer-Policy", "strict-origin-when-cross-origin") + if isAPIRoutePath(c) { + c.Next() + return + } if cfg.Enabled { // Generate nonce for this request - nonce := GenerateNonce() - c.Set(CSPNonceKey, nonce) - - // Replace nonce placeholder in policy - finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'") - c.Header("Content-Security-Policy", finalPolicy) + nonce, err := GenerateNonce() + if err != nil { + // crypto/rand 失败时降级为无 nonce 的 CSP 策略 + log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err) + c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'unsafe-inline'")) + } else { + c.Set(CSPNonceKey, nonce) + c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'nonce-"+nonce+"'")) + } } c.Next() } } +func isAPIRoutePath(c *gin.Context) bool { + if c == nil || c.Request == nil || c.Request.URL == nil { + return false + } + path := c.Request.URL.Path + return strings.HasPrefix(path, "/v1/") || + strings.HasPrefix(path, "/v1beta/") || + strings.HasPrefix(path, "/antigravity/") || + strings.HasPrefix(path, "/sora/") || + strings.HasPrefix(path, "/responses") +} + // enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain. // This allows the application to work correctly even if the config file has an older CSP policy. func enhanceCSPPolicy(policy string) string { diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go index dc7a87d8..031385d0 100644 --- a/backend/internal/server/middleware/security_headers_test.go +++ b/backend/internal/server/middleware/security_headers_test.go @@ -19,7 +19,8 @@ func init() { func TestGenerateNonce(t *testing.T) { t.Run("generates_valid_base64_string", func(t *testing.T) { - nonce := GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) // Should be valid base64 decoded, err := base64.StdEncoding.DecodeString(nonce) @@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) { t.Run("generates_unique_nonces", func(t *testing.T) { nonces := make(map[string]bool) for i := 0; i < 100; i++ { - nonce := GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) assert.False(t, nonces[nonce], "nonce should be unique") nonces[nonce] = true } }) t.Run("nonce_has_expected_length", func(t *testing.T) { - nonce := GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) // 16 bytes -> 24 chars in base64 (with padding) assert.Len(t, nonce, 24) }) @@ -81,7 +84,7 @@ func TestGetNonceFromContext(t *testing.T) { func TestSecurityHeaders(t *testing.T) { t.Run("sets_basic_security_headers", func(t *testing.T) { cfg := config.CSPConfig{Enabled: false} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -96,7 +99,7 @@ func TestSecurityHeaders(t *testing.T) { t.Run("csp_disabled_no_csp_header", func(t *testing.T) { cfg := config.CSPConfig{Enabled: false} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -112,7 +115,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "default-src 'self'", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -128,12 +131,32 @@ func TestSecurityHeaders(t *testing.T) { assert.Contains(t, csp, CloudflareInsightsDomain) }) + t.Run("api_route_skips_csp_nonce_generation", func(t *testing.T) { + cfg := config.CSPConfig{ + Enabled: true, + Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__", + } + middleware := SecurityHeaders(cfg, nil) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + middleware(c) + + assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options")) + assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options")) + assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy")) + assert.Empty(t, w.Header().Get("Content-Security-Policy")) + assert.Empty(t, GetNonceFromContext(c)) + }) + t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) { cfg := config.CSPConfig{ Enabled: true, Policy: "script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -157,7 +180,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -176,7 +199,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: " \t\n ", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -194,7 +217,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -212,7 +235,7 @@ func TestSecurityHeaders(t *testing.T) { t.Run("calls_next_handler", func(t *testing.T) { cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) nextCalled := false router := gin.New() @@ -235,7 +258,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) nonces := make(map[string]bool) for i := 0; i < 10; i++ { @@ -344,7 +367,7 @@ func TestAddToDirective(t *testing.T) { // Benchmark tests func BenchmarkGenerateNonce(b *testing.B) { for i := 0; i < b.N; i++ { - GenerateNonce() + _, _ = GenerateNonce() } } @@ -353,7 +376,7 @@ func BenchmarkSecurityHeadersMiddleware(b *testing.B) { Enabled: true, Policy: "script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index cf9015e4..430edcf8 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -1,7 +1,10 @@ package server import ( + "context" "log" + "sync/atomic" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" @@ -14,6 +17,8 @@ import ( "github.com/redis/go-redis/v9" ) +const frameSrcRefreshTimeout = 5 * time.Second + // SetupRouter 配置路由器中间件和路由 func SetupRouter( r *gin.Engine, @@ -28,10 +33,33 @@ func SetupRouter( cfg *config.Config, redisClient *redis.Client, ) *gin.Engine { + // 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src + var cachedFrameOrigins atomic.Pointer[[]string] + emptyOrigins := []string{} + cachedFrameOrigins.Store(&emptyOrigins) + + refreshFrameOrigins := func() { + ctx, cancel := context.WithTimeout(context.Background(), frameSrcRefreshTimeout) + defer cancel() + origins, err := settingService.GetFrameSrcOrigins(ctx) + if err != nil { + // 获取失败时保留已有缓存,避免 frame-src 被意外清空 + return + } + cachedFrameOrigins.Store(&origins) + } + refreshFrameOrigins() // 启动时初始化 + // 应用中间件 + r.Use(middleware2.RequestLogger()) r.Use(middleware2.Logger()) r.Use(middleware2.CORS(cfg.CORS)) - r.Use(middleware2.SecurityHeaders(cfg.Security.CSP)) + r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() []string { + if p := cachedFrameOrigins.Load(); p != nil { + return *p + } + return nil + })) // Serve embedded frontend with settings injection if available if web.HasEmbeddedFrontend() { @@ -39,11 +67,17 @@ func SetupRouter( if err != nil { log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err) r.Use(web.ServeEmbeddedFrontend()) + settingService.SetOnUpdateCallback(refreshFrameOrigins) } else { - // Register cache invalidation callback - settingService.SetOnUpdateCallback(frontendServer.InvalidateCache) + // Register combined callback: invalidate HTML cache + refresh frame origins + settingService.SetOnUpdateCallback(func() { + frontendServer.InvalidateCache() + refreshFrameOrigins() + }) r.Use(frontendServer.Middleware()) } + } else { + settingService.SetOnUpdateCallback(refreshFrameOrigins) } // 注册路由 @@ -74,6 +108,7 @@ func registerRoutes( // 注册各模块路由 routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient) routes.RegisterUserRoutes(v1, h, jwtAuth) + routes.RegisterSoraClientRoutes(v1, h, jwtAuth) routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg) } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 14815262..c36c36a0 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -34,6 +34,8 @@ func RegisterAdminRoutes( // OpenAI OAuth registerOpenAIOAuthRoutes(admin, h) + // Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立) + registerSoraOAuthRoutes(admin, h) // Gemini OAuth registerGeminiOAuthRoutes(admin, h) @@ -53,6 +55,9 @@ func RegisterAdminRoutes( // 系统设置 registerSettingsRoutes(admin, h) + // 数据管理 + registerDataManagementRoutes(admin, h) + // 运维监控(Ops) registerOpsRoutes(admin, h) @@ -70,6 +75,16 @@ func RegisterAdminRoutes( // 错误透传规则管理 registerErrorPassthroughRoutes(admin, h) + + // API Key 管理 + registerAdminAPIKeyRoutes(admin, h) + } +} + +func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + apiKeys := admin.Group("/api-keys") + { + apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup) } } @@ -101,6 +116,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { runtime.GET("/alert", h.Admin.Ops.GetAlertRuntimeSettings) runtime.PUT("/alert", h.Admin.Ops.UpdateAlertRuntimeSettings) + runtime.GET("/logging", h.Admin.Ops.GetRuntimeLogConfig) + runtime.PUT("/logging", h.Admin.Ops.UpdateRuntimeLogConfig) + runtime.POST("/logging/reset", h.Admin.Ops.ResetRuntimeLogConfig) } // Advanced settings (DB-backed) @@ -144,12 +162,18 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // Request drilldown (success + error) ops.GET("/requests", h.Admin.Ops.ListRequestDetails) + // Indexed system logs + ops.GET("/system-logs", h.Admin.Ops.ListSystemLogs) + ops.POST("/system-logs/cleanup", h.Admin.Ops.CleanupSystemLogs) + ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth) + // Dashboard (vNext - raw path for MVP) ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview) ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend) ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram) ops.GET("/dashboard/error-trend", h.Admin.Ops.GetDashboardErrorTrend) ops.GET("/dashboard/error-distribution", h.Admin.Ops.GetDashboardErrorDistribution) + ops.GET("/dashboard/openai-token-stats", h.Admin.Ops.GetDashboardOpenAITokenStats) } } @@ -160,6 +184,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics) dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend) dashboard.GET("/models", h.Admin.Dashboard.GetModelStats) + dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats) dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend) dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) @@ -192,6 +217,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { groups.GET("", h.Admin.Group.List) groups.GET("/all", h.Admin.Group.GetAll) + groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder) groups.GET("/:id", h.Admin.Group.GetByID) groups.POST("", h.Admin.Group.Create) groups.PUT("/:id", h.Admin.Group.Update) @@ -207,7 +233,9 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.GET("", h.Admin.Account.List) accounts.GET("/:id", h.Admin.Account.GetByID) accounts.POST("", h.Admin.Account.Create) + accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel) accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS) + accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS) accounts.PUT("/:id", h.Admin.Account.Update) accounts.DELETE("/:id", h.Admin.Account.Delete) accounts.POST("/:id/test", h.Admin.Account.Test) @@ -217,6 +245,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.POST("/:id/clear-error", h.Admin.Account.ClearError) accounts.GET("/:id/usage", h.Admin.Account.GetUsage) accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) + accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats) accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable) accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable) @@ -265,6 +294,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } +func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + sora := admin.Group("/sora") + { + sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL) + sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode) + sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken) + sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken) + sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken) + sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken) + sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth) + } +} + func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { gemini := admin.Group("/gemini") { @@ -279,6 +321,7 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { antigravity.POST("/oauth/auth-url", h.Admin.AntigravityOAuth.GenerateAuthURL) antigravity.POST("/oauth/exchange-code", h.Admin.AntigravityOAuth.ExchangeCode) + antigravity.POST("/oauth/refresh-token", h.Admin.AntigravityOAuth.RefreshToken) } } @@ -294,6 +337,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { proxies.PUT("/:id", h.Admin.Proxy.Update) proxies.DELETE("/:id", h.Admin.Proxy.Delete) proxies.POST("/:id/test", h.Admin.Proxy.Test) + proxies.POST("/:id/quality-check", h.Admin.Proxy.CheckQuality) proxies.GET("/:id/stats", h.Admin.Proxy.GetStats) proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts) proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete) @@ -308,6 +352,7 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { codes.GET("/stats", h.Admin.Redeem.GetStats) codes.GET("/export", h.Admin.Redeem.Export) codes.GET("/:id", h.Admin.Redeem.GetByID) + codes.POST("/create-and-redeem", h.Admin.Redeem.CreateAndRedeem) codes.POST("/generate", h.Admin.Redeem.Generate) codes.DELETE("/:id", h.Admin.Redeem.Delete) codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete) @@ -341,6 +386,38 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // 流超时处理配置 adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) + // Sora S3 存储配置 + adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings) + adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings) + adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection) + adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles) + adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile) + adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile) + adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile) + adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile) + } +} + +func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + dataManagement := admin.Group("/data-management") + { + dataManagement.GET("/agent/health", h.Admin.DataManagement.GetAgentHealth) + dataManagement.GET("/config", h.Admin.DataManagement.GetConfig) + dataManagement.PUT("/config", h.Admin.DataManagement.UpdateConfig) + dataManagement.GET("/sources/:source_type/profiles", h.Admin.DataManagement.ListSourceProfiles) + dataManagement.POST("/sources/:source_type/profiles", h.Admin.DataManagement.CreateSourceProfile) + dataManagement.PUT("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.UpdateSourceProfile) + dataManagement.DELETE("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.DeleteSourceProfile) + dataManagement.POST("/sources/:source_type/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveSourceProfile) + dataManagement.POST("/s3/test", h.Admin.DataManagement.TestS3) + dataManagement.GET("/s3/profiles", h.Admin.DataManagement.ListS3Profiles) + dataManagement.POST("/s3/profiles", h.Admin.DataManagement.CreateS3Profile) + dataManagement.PUT("/s3/profiles/:profile_id", h.Admin.DataManagement.UpdateS3Profile) + dataManagement.DELETE("/s3/profiles/:profile_id", h.Admin.DataManagement.DeleteS3Profile) + dataManagement.POST("/s3/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveS3Profile) + dataManagement.POST("/backups", h.Admin.DataManagement.CreateBackupJob) + dataManagement.GET("/backups", h.Admin.DataManagement.ListBackupJobs) + dataManagement.GET("/backups/:job_id", h.Admin.DataManagement.GetBackupJob) } } diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 26d79605..c168820c 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -24,10 +24,19 @@ func RegisterAuthRoutes( // 公开接口 auth := v1.Group("/auth") { - auth.POST("/register", h.Auth.Register) - auth.POST("/login", h.Auth.Login) - auth.POST("/login/2fa", h.Auth.Login2FA) - auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + // 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close) + auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Register) + auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Login) + auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Login2FA) + auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.SendVerifyCode) // Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close) auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, diff --git a/backend/internal/server/routes/auth_rate_limit_integration_test.go b/backend/internal/server/routes/auth_rate_limit_integration_test.go new file mode 100644 index 00000000..8a0ef860 --- /dev/null +++ b/backend/internal/server/routes/auth_rate_limit_integration_test.go @@ -0,0 +1,111 @@ +//go:build integration + +package routes + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + tcredis "github.com/testcontainers/testcontainers-go/modules/redis" +) + +const authRouteRedisImageTag = "redis:8.4-alpine" + +func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) { + ctx := context.Background() + rdb := startAuthRouteRedis(t, ctx) + + router := newAuthRoutesTestRouter(rdb) + const path = "/api/v1/auth/register" + + for i := 1; i <= 6; i++ { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "198.51.100.10:23456" + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if i <= 5 { + require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i) + continue + } + require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流") + require.Contains(t, w.Body.String(), "rate limit exceeded") + } +} + +func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client { + t.Helper() + ensureAuthRouteDockerAvailable(t) + + redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag) + require.NoError(t, err) + t.Cleanup(func() { + _ = redisContainer.Terminate(ctx) + }) + + redisHost, err := redisContainer.Host(ctx) + require.NoError(t, err) + redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp") + require.NoError(t, err) + + rdb := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()), + DB: 0, + }) + require.NoError(t, rdb.Ping(ctx).Err()) + t.Cleanup(func() { + _ = rdb.Close() + }) + return rdb +} + +func ensureAuthRouteDockerAvailable(t *testing.T) { + t.Helper() + if authRouteDockerAvailable() { + return + } + t.Skip("Docker 未启用,跳过认证限流集成测试") +} + +func authRouteDockerAvailable() bool { + if os.Getenv("DOCKER_HOST") != "" { + return true + } + + socketCandidates := []string{ + "/var/run/docker.sock", + filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"), + filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"), + filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"), + filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"), + } + + for _, socket := range socketCandidates { + if socket == "" { + continue + } + if _, err := os.Stat(socket); err == nil { + return true + } + } + return false +} + +func authRouteUserHomeDir() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return home +} diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go new file mode 100644 index 00000000..5ce8497c --- /dev/null +++ b/backend/internal/server/routes/auth_rate_limit_test.go @@ -0,0 +1,67 @@ +package routes + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/handler" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + v1 := router.Group("/api/v1") + + RegisterAuthRoutes( + v1, + &handler.Handlers{ + Auth: &handler.AuthHandler{}, + Setting: &handler.SettingHandler{}, + }, + servermiddleware.JWTAuthMiddleware(func(c *gin.Context) { + c.Next() + }), + redisClient, + ) + + return router +} + +func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:1", + DialTimeout: 50 * time.Millisecond, + ReadTimeout: 50 * time.Millisecond, + WriteTimeout: 50 * time.Millisecond, + }) + t.Cleanup(func() { + _ = rdb.Close() + }) + + router := newAuthRoutesTestRouter(rdb) + paths := []string{ + "/api/v1/auth/register", + "/api/v1/auth/login", + "/api/v1/auth/login/2fa", + "/api/v1/auth/send-verify-code", + } + + for _, path := range paths { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "203.0.113.10:12345" + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path) + require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path) + } +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index bf019ce3..6bd91b85 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -1,6 +1,8 @@ package routes import ( + "net/http" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -20,6 +22,11 @@ func RegisterGatewayRoutes( cfg *config.Config, ) { bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) + soraMaxBodySize := cfg.Gateway.SoraMaxBodySize + if soraMaxBodySize <= 0 { + soraMaxBodySize = cfg.Gateway.MaxBodySize + } + soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize) clientRequestID := middleware.ClientRequestID() opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) @@ -36,6 +43,16 @@ func RegisterGatewayRoutes( gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API gateway.POST("/responses", h.OpenAIGateway.Responses) + gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) + // 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。 + gateway.POST("/chat/completions", func(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.", + }, + }) + }) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -53,6 +70,7 @@ func RegisterGatewayRoutes( // OpenAI Responses API(不带v1前缀的别名) r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) + r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels) @@ -82,4 +100,25 @@ func RegisterGatewayRoutes( antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) } + + // Sora 专用路由(强制使用 sora 平台) + soraV1 := r.Group("/sora/v1") + soraV1.Use(soraBodyLimit) + soraV1.Use(clientRequestID) + soraV1.Use(opsErrorLogger) + soraV1.Use(middleware.ForcePlatform(service.PlatformSora)) + soraV1.Use(gin.HandlerFunc(apiKeyAuth)) + { + soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions) + soraV1.GET("/models", h.Gateway.Models) + } + + // Sora 媒体代理(可选 API Key 验证) + if cfg.Gateway.SoraMediaRequireAPIKey { + r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy) + } else { + r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy) + } + // Sora 媒体代理(签名 URL,无需 API Key) + r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned) } diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go new file mode 100644 index 00000000..40ae0436 --- /dev/null +++ b/backend/internal/server/routes/sora_client.go @@ -0,0 +1,33 @@ +package routes + +import ( + "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + + "github.com/gin-gonic/gin" +) + +// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。 +func RegisterSoraClientRoutes( + v1 *gin.RouterGroup, + h *handler.Handlers, + jwtAuth middleware.JWTAuthMiddleware, +) { + if h.SoraClient == nil { + return + } + + authenticated := v1.Group("/sora") + authenticated.Use(gin.HandlerFunc(jwtAuth)) + { + authenticated.POST("/generate", h.SoraClient.Generate) + authenticated.GET("/generations", h.SoraClient.ListGenerations) + authenticated.GET("/generations/:id", h.SoraClient.GetGeneration) + authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration) + authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration) + authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage) + authenticated.GET("/quota", h.SoraClient.GetQuota) + authenticated.GET("/models", h.SoraClient.GetModels) + authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus) + } +} diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index a6ae8a68..81e91aeb 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -3,11 +3,14 @@ package service import ( "encoding/json" + "hash/fnv" + "reflect" "sort" "strconv" "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/domain" ) @@ -50,6 +53,14 @@ type Account struct { AccountGroups []AccountGroup GroupIDs []int64 Groups []*Group + + // model_mapping 热路径缓存(非持久化字段) + modelMappingCache map[string]string + modelMappingCacheReady bool + modelMappingCacheCredentialsPtr uintptr + modelMappingCacheRawPtr uintptr + modelMappingCacheRawLen int + modelMappingCacheRawSig uint64 } type TempUnschedulableRule struct { @@ -349,6 +360,39 @@ func parseTempUnschedInt(value any) int { } func (a *Account) GetModelMapping() map[string]string { + credentialsPtr := mapPtr(a.Credentials) + rawMapping, _ := a.Credentials["model_mapping"].(map[string]any) + rawPtr := mapPtr(rawMapping) + rawLen := len(rawMapping) + rawSig := uint64(0) + rawSigReady := false + + if a.modelMappingCacheReady && + a.modelMappingCacheCredentialsPtr == credentialsPtr && + a.modelMappingCacheRawPtr == rawPtr && + a.modelMappingCacheRawLen == rawLen { + rawSig = modelMappingSignature(rawMapping) + rawSigReady = true + if a.modelMappingCacheRawSig == rawSig { + return a.modelMappingCache + } + } + + mapping := a.resolveModelMapping(rawMapping) + if !rawSigReady { + rawSig = modelMappingSignature(rawMapping) + } + + a.modelMappingCache = mapping + a.modelMappingCacheReady = true + a.modelMappingCacheCredentialsPtr = credentialsPtr + a.modelMappingCacheRawPtr = rawPtr + a.modelMappingCacheRawLen = rawLen + a.modelMappingCacheRawSig = rawSig + return mapping +} + +func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string { if a.Credentials == nil { // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { @@ -356,25 +400,31 @@ func (a *Account) GetModelMapping() map[string]string { } return nil } - raw, ok := a.Credentials["model_mapping"] - if !ok || raw == nil { + if len(rawMapping) == 0 { // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { return domain.DefaultAntigravityModelMapping } return nil } - if m, ok := raw.(map[string]any); ok { - result := make(map[string]string) - for k, v := range m { - if s, ok := v.(string); ok { - result[k] = s - } - } - if len(result) > 0 { - return result + + result := make(map[string]string) + for k, v := range rawMapping { + if s, ok := v.(string); ok { + result[k] = s } } + if len(result) > 0 { + if a.Platform == domain.PlatformAntigravity { + ensureAntigravityDefaultPassthroughs(result, []string{ + "gemini-3-flash", + "gemini-3.1-pro-high", + "gemini-3.1-pro-low", + }) + } + return result + } + // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { return domain.DefaultAntigravityModelMapping @@ -382,6 +432,58 @@ func (a *Account) GetModelMapping() map[string]string { return nil } +func mapPtr(m map[string]any) uintptr { + if m == nil { + return 0 + } + return reflect.ValueOf(m).Pointer() +} + +func modelMappingSignature(rawMapping map[string]any) uint64 { + if len(rawMapping) == 0 { + return 0 + } + keys := make([]string, 0, len(rawMapping)) + for k := range rawMapping { + keys = append(keys, k) + } + sort.Strings(keys) + + h := fnv.New64a() + for _, k := range keys { + _, _ = h.Write([]byte(k)) + _, _ = h.Write([]byte{0}) + if v, ok := rawMapping[k].(string); ok { + _, _ = h.Write([]byte(v)) + } else { + _, _ = h.Write([]byte{1}) + } + _, _ = h.Write([]byte{0xff}) + } + return h.Sum64() +} + +func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) { + if mapping == nil || model == "" { + return + } + if _, exists := mapping[model]; exists { + return + } + for pattern := range mapping { + if matchWildcard(pattern, model) { + return + } + } + mapping[model] = model +} + +func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []string) { + for _, model := range models { + ensureAntigravityDefaultPassthrough(mapping, model) + } +} + // IsModelSupported 检查模型是否在 model_mapping 中(支持通配符) // 如果未配置 mapping,返回 true(允许所有模型) func (a *Account) IsModelSupported(requestedModel string) bool { @@ -425,6 +527,22 @@ func (a *Account) GetBaseURL() string { if baseURL == "" { return "https://api.anthropic.com" } + if a.Platform == PlatformAntigravity { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } + return baseURL +} + +// GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。 +// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。 +func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string { + baseURL := strings.TrimSpace(a.GetCredential("base_url")) + if baseURL == "" { + return defaultBaseURL + } + if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } return baseURL } @@ -680,6 +798,204 @@ func (a *Account) IsMixedSchedulingEnabled() bool { return false } +// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。 +// +// 新字段:accounts.extra.openai_passthrough。 +// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。 +// 字段缺失或类型不正确时,按 false(关闭)处理。 +func (a *Account) IsOpenAIPassthroughEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + if enabled, ok := a.Extra["openai_passthrough"].(bool); ok { + return enabled + } + if enabled, ok := a.Extra["openai_oauth_passthrough"].(bool); ok { + return enabled + } + return false +} + +// IsOpenAIResponsesWebSocketV2Enabled 返回 OpenAI 账号是否开启 Responses WebSocket v2。 +// +// 分类型新字段: +// - OAuth 账号:accounts.extra.openai_oauth_responses_websockets_v2_enabled +// - API Key 账号:accounts.extra.openai_apikey_responses_websockets_v2_enabled +// +// 兼容字段: +// - accounts.extra.responses_websockets_v2_enabled +// - accounts.extra.openai_ws_enabled(历史开关) +// +// 优先级: +// 1. 按账号类型读取分类型字段 +// 2. 分类型字段缺失时,回退兼容字段 +func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + if a.IsOpenAIOAuth() { + if enabled, ok := a.Extra["openai_oauth_responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + } + if a.IsOpenAIApiKey() { + if enabled, ok := a.Extra["openai_apikey_responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + } + if enabled, ok := a.Extra["responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + if enabled, ok := a.Extra["openai_ws_enabled"].(bool); ok { + return enabled + } + return false +} + +const ( + OpenAIWSIngressModeOff = "off" + OpenAIWSIngressModeShared = "shared" + OpenAIWSIngressModeDedicated = "dedicated" +) + +func normalizeOpenAIWSIngressMode(mode string) string { + switch strings.ToLower(strings.TrimSpace(mode)) { + case OpenAIWSIngressModeOff: + return OpenAIWSIngressModeOff + case OpenAIWSIngressModeShared: + return OpenAIWSIngressModeShared + case OpenAIWSIngressModeDedicated: + return OpenAIWSIngressModeDedicated + default: + return "" + } +} + +func normalizeOpenAIWSIngressDefaultMode(mode string) string { + if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" { + return normalized + } + return OpenAIWSIngressModeShared +} + +// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。 +// +// 优先级: +// 1. 分类型 mode 新字段(string) +// 2. 分类型 enabled 旧字段(bool) +// 3. 兼容 enabled 旧字段(bool) +// 4. defaultMode(非法时回退 shared) +func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string { + resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode) + if a == nil || !a.IsOpenAI() { + return OpenAIWSIngressModeOff + } + if a.Extra == nil { + return resolvedDefault + } + + resolveModeString := func(key string) (string, bool) { + raw, ok := a.Extra[key] + if !ok { + return "", false + } + mode, ok := raw.(string) + if !ok { + return "", false + } + normalized := normalizeOpenAIWSIngressMode(mode) + if normalized == "" { + return "", false + } + return normalized, true + } + resolveBoolMode := func(key string) (string, bool) { + raw, ok := a.Extra[key] + if !ok { + return "", false + } + enabled, ok := raw.(bool) + if !ok { + return "", false + } + if enabled { + return OpenAIWSIngressModeShared, true + } + return OpenAIWSIngressModeOff, true + } + + if a.IsOpenAIOAuth() { + if mode, ok := resolveModeString("openai_oauth_responses_websockets_v2_mode"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_oauth_responses_websockets_v2_enabled"); ok { + return mode + } + } + if a.IsOpenAIApiKey() { + if mode, ok := resolveModeString("openai_apikey_responses_websockets_v2_mode"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_apikey_responses_websockets_v2_enabled"); ok { + return mode + } + } + if mode, ok := resolveBoolMode("responses_websockets_v2_enabled"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_ws_enabled"); ok { + return mode + } + return resolvedDefault +} + +// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。 +// 字段:accounts.extra.openai_ws_force_http。 +func (a *Account) IsOpenAIWSForceHTTPEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + enabled, ok := a.Extra["openai_ws_force_http"].(bool) + return ok && enabled +} + +// IsOpenAIWSAllowStoreRecoveryEnabled 返回账号级 store 恢复开关。 +// 字段:accounts.extra.openai_ws_allow_store_recovery。 +func (a *Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + enabled, ok := a.Extra["openai_ws_allow_store_recovery"].(bool) + return ok && enabled +} + +// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。 +func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool { + return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled() +} + +// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。 +// 字段:accounts.extra.anthropic_passthrough。 +// 字段缺失或类型不正确时,按 false(关闭)处理。 +func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool { + if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil { + return false + } + enabled, ok := a.Extra["anthropic_passthrough"].(bool) + return ok && enabled +} + +// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。 +// 字段:accounts.extra.codex_cli_only。 +// 字段缺失或类型不正确时,按 false(关闭)处理。 +func (a *Account) IsCodexCLIOnlyEnabled() bool { + if a == nil || !a.IsOpenAIOAuth() || a.Extra == nil { + return false + } + enabled, ok := a.Extra["codex_cli_only"].(bool) + return ok && enabled +} + // WindowCostSchedulability 窗口费用调度状态 type WindowCostSchedulability int @@ -717,6 +1033,26 @@ func (a *Account) IsTLSFingerprintEnabled() bool { return false } +// GetUserMsgQueueMode 获取用户消息队列模式 +// "serialize" = 串行队列, "throttle" = 软性限速, "" = 未设置(使用全局配置) +func (a *Account) GetUserMsgQueueMode() string { + if a.Extra == nil { + return "" + } + // 优先读取新字段 user_msg_queue_mode(白名单校验,非法值视为未设置) + if mode, ok := a.Extra["user_msg_queue_mode"].(string); ok && mode != "" { + if mode == config.UMQModeSerialize || mode == config.UMQModeThrottle { + return mode + } + return "" // 非法值 fallback 到全局配置 + } + // 向后兼容: user_msg_queue_enabled: true → "serialize" + if enabled, ok := a.Extra["user_msg_queue_enabled"].(bool); ok && enabled { + return config.UMQModeSerialize + } + return "" +} + // IsSessionIDMaskingEnabled 检查是否启用会话ID伪装 // 仅适用于 Anthropic OAuth/SetupToken 类型账号 // 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID, @@ -736,6 +1072,38 @@ func (a *Account) IsSessionIDMaskingEnabled() bool { return false } +// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换 +// 仅适用于 Anthropic OAuth/SetupToken 类型账号 +// 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h) +func (a *Account) IsCacheTTLOverrideEnabled() bool { + if !a.IsAnthropicOAuthOrSetupToken() { + return false + } + if a.Extra == nil { + return false + } + if v, ok := a.Extra["cache_ttl_override_enabled"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + +// GetCacheTTLOverrideTarget 获取缓存 TTL 强制替换的目标类型 +// 返回 "5m" 或 "1h",默认 "5m" +func (a *Account) GetCacheTTLOverrideTarget() string { + if a.Extra == nil { + return "5m" + } + if v, ok := a.Extra["cache_ttl_override_target"]; ok { + if target, ok := v.(string); ok && (target == "5m" || target == "1h") { + return target + } + } + return "5m" +} + // GetWindowCostLimit 获取 5h 窗口费用阈值(美元) // 返回 0 表示未启用 func (a *Account) GetWindowCostLimit() float64 { @@ -790,6 +1158,80 @@ func (a *Account) GetSessionIdleTimeoutMinutes() int { return 5 } +// GetBaseRPM 获取基础 RPM 限制 +// 返回 0 表示未启用(负数视为无效配置,按 0 处理) +func (a *Account) GetBaseRPM() int { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra["base_rpm"]; ok { + val := parseExtraInt(v) + if val > 0 { + return val + } + } + return 0 +} + +// GetRPMStrategy 获取 RPM 策略 +// "tiered" = 三区模型(默认), "sticky_exempt" = 粘性豁免 +func (a *Account) GetRPMStrategy() string { + if a.Extra == nil { + return "tiered" + } + if v, ok := a.Extra["rpm_strategy"]; ok { + if s, ok := v.(string); ok && s == "sticky_exempt" { + return "sticky_exempt" + } + } + return "tiered" +} + +// GetRPMStickyBuffer 获取 RPM 粘性缓冲数量 +// tiered 模式下的黄区大小,默认为 base_rpm 的 20%(至少 1) +func (a *Account) GetRPMStickyBuffer() int { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra["rpm_sticky_buffer"]; ok { + val := parseExtraInt(v) + if val > 0 { + return val + } + } + base := a.GetBaseRPM() + buffer := base / 5 + if buffer < 1 && base > 0 { + buffer = 1 + } + return buffer +} + +// CheckRPMSchedulability 根据当前 RPM 计数检查调度状态 +// 复用 WindowCostSchedulability 三态:Schedulable / StickyOnly / NotSchedulable +func (a *Account) CheckRPMSchedulability(currentRPM int) WindowCostSchedulability { + baseRPM := a.GetBaseRPM() + if baseRPM <= 0 { + return WindowCostSchedulable + } + + if currentRPM < baseRPM { + return WindowCostSchedulable + } + + strategy := a.GetRPMStrategy() + if strategy == "sticky_exempt" { + return WindowCostStickyOnly // 粘性豁免无红区 + } + + // tiered: 黄区 + 红区 + buffer := a.GetRPMStickyBuffer() + if currentRPM < baseRPM+buffer { + return WindowCostStickyOnly + } + return WindowCostNotSchedulable +} + // CheckWindowCostSchedulability 根据当前窗口费用检查调度状态 // - 费用 < 阈值: WindowCostSchedulable(可正常调度) // - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话) @@ -853,6 +1295,12 @@ func parseExtraFloat64(value any) float64 { } // parseExtraInt 从 extra 字段解析 int 值 +// ParseExtraInt 从 extra 字段的 any 值解析为 int。 +// 支持 int, int64, float64, json.Number, string 类型,无法解析时返回 0。 +func ParseExtraInt(value any) int { + return parseExtraInt(value) +} + func parseExtraInt(value any) int { switch v := value.(type) { case int: diff --git a/backend/internal/service/account_anthropic_passthrough_test.go b/backend/internal/service/account_anthropic_passthrough_test.go new file mode 100644 index 00000000..e66407a3 --- /dev/null +++ b/backend/internal/service/account_anthropic_passthrough_test.go @@ -0,0 +1,62 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAccount_IsAnthropicAPIKeyPassthroughEnabled(t *testing.T) { + t.Run("Anthropic API Key 开启", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + require.True(t, account.IsAnthropicAPIKeyPassthroughEnabled()) + }) + + t.Run("Anthropic API Key 关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "anthropic_passthrough": false, + }, + } + require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled()) + }) + + t.Run("字段类型非法默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "anthropic_passthrough": "true", + }, + } + require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled()) + }) + + t.Run("非 Anthropic API Key 账号始终关闭", func(t *testing.T) { + oauth := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + require.False(t, oauth.IsAnthropicAPIKeyPassthroughEnabled()) + + openai := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + require.False(t, openai.IsAnthropicAPIKeyPassthroughEnabled()) + }) +} diff --git a/backend/internal/service/account_base_url_test.go b/backend/internal/service/account_base_url_test.go new file mode 100644 index 00000000..a1322193 --- /dev/null +++ b/backend/internal/service/account_base_url_test.go @@ -0,0 +1,160 @@ +//go:build unit + +package service + +import ( + "testing" +) + +func TestGetBaseURL(t *testing.T) { + tests := []struct { + name string + account Account + expected string + }{ + { + name: "non-apikey type returns empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAnthropic, + }, + expected: "", + }, + { + name: "apikey without base_url returns default anthropic", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{}, + }, + expected: "https://api.anthropic.com", + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{"base_url": "https://custom.example.com"}, + }, + expected: "https://custom.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash before appending", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity non-apikey returns empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetBaseURL() + if result != tt.expected { + t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestGetGeminiBaseURL(t *testing.T) { + const defaultGeminiURL = "https://generativelanguage.googleapis.com" + + tests := []struct { + name string + account Account + expected string + }{ + { + name: "apikey without base_url returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{}, + }, + expected: defaultGeminiURL, + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"}, + }, + expected: "https://custom-gemini.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity oauth does NOT append /antigravity", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com", + }, + { + name: "oauth without base_url returns default", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{}, + }, + expected: defaultGeminiURL, + }, + { + name: "nil credentials returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + expected: defaultGeminiURL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetGeminiBaseURL(defaultGeminiURL) + if result != tt.expected { + t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/account_intercept_warmup_test.go b/backend/internal/service/account_intercept_warmup_test.go new file mode 100644 index 00000000..f117fd8d --- /dev/null +++ b/backend/internal/service/account_intercept_warmup_test.go @@ -0,0 +1,66 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAccount_IsInterceptWarmupEnabled(t *testing.T) { + tests := []struct { + name string + credentials map[string]any + expected bool + }{ + { + name: "nil credentials", + credentials: nil, + expected: false, + }, + { + name: "empty map", + credentials: map[string]any{}, + expected: false, + }, + { + name: "field not present", + credentials: map[string]any{"access_token": "tok"}, + expected: false, + }, + { + name: "field is true", + credentials: map[string]any{"intercept_warmup_requests": true}, + expected: true, + }, + { + name: "field is false", + credentials: map[string]any{"intercept_warmup_requests": false}, + expected: false, + }, + { + name: "field is string true", + credentials: map[string]any{"intercept_warmup_requests": "true"}, + expected: false, + }, + { + name: "field is int 1", + credentials: map[string]any{"intercept_warmup_requests": 1}, + expected: false, + }, + { + name: "field is nil", + credentials: map[string]any{"intercept_warmup_requests": nil}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Credentials: tt.credentials} + result := a.IsInterceptWarmupEnabled() + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go new file mode 100644 index 00000000..a85c68ec --- /dev/null +++ b/backend/internal/service/account_openai_passthrough_test.go @@ -0,0 +1,294 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAccount_IsOpenAIPassthroughEnabled(t *testing.T) { + t.Run("新字段开启", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.True(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("兼容旧字段", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_passthrough": true, + }, + } + require.True(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("非OpenAI账号始终关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.False(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("空额外配置默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + require.False(t, account.IsOpenAIPassthroughEnabled()) + }) +} + +func TestAccount_IsOpenAIOAuthPassthroughEnabled(t *testing.T) { + t.Run("仅OAuth类型允许返回开启", func(t *testing.T) { + oauthAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.True(t, oauthAccount.IsOpenAIOAuthPassthroughEnabled()) + + apiKeyAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.False(t, apiKeyAccount.IsOpenAIOAuthPassthroughEnabled()) + }) +} + +func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) { + t.Run("OpenAI OAuth 开启", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": true, + }, + } + require.True(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("OpenAI OAuth 关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": false, + }, + } + require.False(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("字段缺失默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{}, + } + require.False(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("类型非法默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": "true", + }, + } + require.False(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("非 OAuth 账号始终关闭", func(t *testing.T) { + apiKeyAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "codex_cli_only": true, + }, + } + require.False(t, apiKeyAccount.IsCodexCLIOnlyEnabled()) + + otherPlatform := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": true, + }, + } + require.False(t, otherPlatform.IsCodexCLIOnlyEnabled()) + }) +} + +func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) { + t.Run("OAuth使用OAuth专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("API Key使用API Key专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("分类型新键优先于兼容键", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": false, + "responses_websockets_v2_enabled": true, + "openai_ws_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("分类型键缺失时回退兼容键", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("非OpenAI账号默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) +} + +func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { + t.Run("default fallback to shared", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{}, + } + require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("")) + require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) + }) + + t.Run("oauth mode field has highest priority", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + "openai_oauth_responses_websockets_v2_enabled": false, + "responses_websockets_v2_enabled": false, + }, + } + require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + }) + + t.Run("legacy enabled maps to shared", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("legacy disabled maps to off", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": false, + "responses_websockets_v2_enabled": true, + }, + } + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + }) + + t.Run("non openai always off", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated)) + }) +} + +func TestAccount_OpenAIWSExtraFlags(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_ws_force_http": true, + "openai_ws_allow_store_recovery": true, + }, + } + require.True(t, account.IsOpenAIWSForceHTTPEnabled()) + require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled()) + + off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}} + require.False(t, off.IsOpenAIWSForceHTTPEnabled()) + require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled()) + + var nilAccount *Account + require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled()) + + nonOpenAI := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_ws_allow_store_recovery": true, + }, + } + require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled()) +} diff --git a/backend/internal/service/account_rpm_test.go b/backend/internal/service/account_rpm_test.go new file mode 100644 index 00000000..9d91f3e0 --- /dev/null +++ b/backend/internal/service/account_rpm_test.go @@ -0,0 +1,120 @@ +package service + +import ( + "encoding/json" + "testing" +) + +func TestGetBaseRPM(t *testing.T) { + tests := []struct { + name string + extra map[string]any + expected int + }{ + {"nil extra", nil, 0}, + {"no key", map[string]any{}, 0}, + {"zero", map[string]any{"base_rpm": 0}, 0}, + {"int value", map[string]any{"base_rpm": 15}, 15}, + {"float value", map[string]any{"base_rpm": 15.0}, 15}, + {"string value", map[string]any{"base_rpm": "15"}, 15}, + {"negative value", map[string]any{"base_rpm": -5}, 0}, + {"int64 value", map[string]any{"base_rpm": int64(20)}, 20}, + {"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.GetBaseRPM(); got != tt.expected { + t.Errorf("GetBaseRPM() = %d, want %d", got, tt.expected) + } + }) + } +} + +func TestGetRPMStrategy(t *testing.T) { + tests := []struct { + name string + extra map[string]any + expected string + }{ + {"nil extra", nil, "tiered"}, + {"no key", map[string]any{}, "tiered"}, + {"tiered", map[string]any{"rpm_strategy": "tiered"}, "tiered"}, + {"sticky_exempt", map[string]any{"rpm_strategy": "sticky_exempt"}, "sticky_exempt"}, + {"invalid", map[string]any{"rpm_strategy": "foobar"}, "tiered"}, + {"empty string fallback", map[string]any{"rpm_strategy": ""}, "tiered"}, + {"numeric value fallback", map[string]any{"rpm_strategy": 123}, "tiered"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.GetRPMStrategy(); got != tt.expected { + t.Errorf("GetRPMStrategy() = %q, want %q", got, tt.expected) + } + }) + } +} + +func TestCheckRPMSchedulability(t *testing.T) { + tests := []struct { + name string + extra map[string]any + currentRPM int + expected WindowCostSchedulability + }{ + {"disabled", map[string]any{}, 100, WindowCostSchedulable}, + {"green zone", map[string]any{"base_rpm": 15}, 10, WindowCostSchedulable}, + {"yellow zone tiered", map[string]any{"base_rpm": 15}, 15, WindowCostStickyOnly}, + {"red zone tiered", map[string]any{"base_rpm": 15}, 18, WindowCostNotSchedulable}, + {"sticky_exempt at limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 15, WindowCostStickyOnly}, + {"sticky_exempt over limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 100, WindowCostStickyOnly}, + {"custom buffer", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 14, WindowCostStickyOnly}, + {"custom buffer red", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 15, WindowCostNotSchedulable}, + {"base_rpm=1 green", map[string]any{"base_rpm": 1}, 0, WindowCostSchedulable}, + {"base_rpm=1 yellow (at limit)", map[string]any{"base_rpm": 1}, 1, WindowCostStickyOnly}, + {"base_rpm=1 red (at limit+buffer)", map[string]any{"base_rpm": 1}, 2, WindowCostNotSchedulable}, + {"negative currentRPM", map[string]any{"base_rpm": 15}, -1, WindowCostSchedulable}, + {"base_rpm negative disabled", map[string]any{"base_rpm": -5}, 10, WindowCostSchedulable}, + {"very high currentRPM", map[string]any{"base_rpm": 10}, 9999, WindowCostNotSchedulable}, + {"sticky_exempt very high currentRPM", map[string]any{"base_rpm": 10, "rpm_strategy": "sticky_exempt"}, 9999, WindowCostStickyOnly}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.CheckRPMSchedulability(tt.currentRPM); got != tt.expected { + t.Errorf("CheckRPMSchedulability(%d) = %d, want %d", tt.currentRPM, got, tt.expected) + } + }) + } +} + +func TestGetRPMStickyBuffer(t *testing.T) { + tests := []struct { + name string + extra map[string]any + expected int + }{ + {"nil extra", nil, 0}, + {"no keys", map[string]any{}, 0}, + {"base_rpm=0", map[string]any{"base_rpm": 0}, 0}, + {"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1}, + {"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1}, + {"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1}, + {"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2}, + {"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3}, + {"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20}, + {"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5}, + {"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2}, + {"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2}, + {"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7}, + {"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.GetRPMStickyBuffer(); got != tt.expected { + t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 90365d2f..a3707184 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -25,11 +25,17 @@ type AccountRepository interface { // GetByCRSAccountID finds an account previously synced from CRS. // Returns (nil, nil) if not found. GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) + // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora') + // 用于查找通过 linked_openai_account_id 关联的 Sora 账号 + FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) + // ListCRSAccountIDs returns a map of crs_account_id -> local account ID + // for all accounts that have been synced from CRS. + ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) Update(ctx context.Context, account *Account) error Delete(ctx context.Context, id int64) error List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListActive(ctx context.Context) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error) @@ -50,7 +56,6 @@ type AccountRepository interface { ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error - SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error @@ -114,6 +119,10 @@ type AccountService struct { groupRepo GroupRepository } +type groupExistenceBatchChecker interface { + ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) +} + // NewAccountService 创建账号服务实例 func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService { return &AccountService{ @@ -126,11 +135,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) { // 验证分组是否存在(如果指定了分组) if len(req.GroupIDs) > 0 { - for _, groupID := range req.GroupIDs { - _, err := s.groupRepo.GetByID(ctx, groupID) - if err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil { + return nil, err } } @@ -251,11 +257,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount // 先验证分组是否存在(在任何写操作之前) if req.GroupIDs != nil { - for _, groupID := range *req.GroupIDs { - _, err := s.groupRepo.GetByID(ctx, groupID) - if err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil { + return nil, err } } @@ -295,6 +298,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error { return nil } +func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + if s.groupRepo == nil { + return fmt.Errorf("group repository not configured") + } + + if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok { + existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs) + if err != nil { + return fmt.Errorf("check groups exists: %w", err) + } + for _, groupID := range groupIDs { + if groupID <= 0 { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + if !existsByID[groupID] { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + } + return nil + } + + for _, groupID := range groupIDs { + _, err := s.groupRepo.GetByID(ctx, groupID) + if err != nil { + return fmt.Errorf("get group: %w", err) + } + } + return nil +} + // UpdateStatus 更新账号状态 func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error { account, err := s.accountRepo.GetByID(ctx, id) diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index e5eabfc6..a466b68a 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -54,6 +54,14 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st panic("unexpected GetByCRSAccountID call") } +func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { + panic("unexpected FindByExtraField call") +} + +func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { + panic("unexpected ListCRSAccountIDs call") +} + func (s *accountRepoStub) Update(ctx context.Context, account *Account) error { panic("unexpected Update call") } @@ -71,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination panic("unexpected List call") } -func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } @@ -143,10 +151,6 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt panic("unexpected SetRateLimited call") } -func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - panic("unexpected SetAntigravityQuotaScopeLimit call") -} - func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { panic("unexpected SetModelRateLimit call") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 3290fe52..c55e418d 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -12,13 +12,17 @@ import ( "io" "log" "net/http" + "net/url" "regexp" "strings" + "sync" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -31,6 +35,11 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`) const ( testClaudeAPIURL = "https://api.anthropic.com/v1/messages" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" + soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 + soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions" + soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine" + soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap" + soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check" ) // TestEvent represents a SSE event for account testing @@ -38,6 +47,9 @@ type TestEvent struct { Type string `json:"type"` Text string `json:"text,omitempty"` Model string `json:"model,omitempty"` + Status string `json:"status,omitempty"` + Code string `json:"code,omitempty"` + Data any `json:"data,omitempty"` Success bool `json:"success,omitempty"` Error string `json:"error,omitempty"` } @@ -49,8 +61,13 @@ type AccountTestService struct { antigravityGatewayService *AntigravityGatewayService httpUpstream HTTPUpstream cfg *config.Config + soraTestGuardMu sync.Mutex + soraTestLastRun map[int64]time.Time + soraTestCooldown time.Duration } +const defaultSoraTestCooldown = 10 * time.Second + // NewAccountTestService creates a new AccountTestService func NewAccountTestService( accountRepo AccountRepository, @@ -65,6 +82,8 @@ func NewAccountTestService( antigravityGatewayService: antigravityGatewayService, httpUpstream: httpUpstream, cfg: cfg, + soraTestLastRun: make(map[int64]time.Time), + soraTestCooldown: defaultSoraTestCooldown, } } @@ -163,6 +182,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int return s.testAntigravityAccountConnection(c, account, modelID) } + if account.Platform == PlatformSora { + return s.testSoraAccountConnection(c, account) + } + return s.testClaudeAccountConnection(c, account, modelID) } @@ -245,7 +268,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account // Set common headers req.Header.Set("Content-Type", "application/json") req.Header.Set("anthropic-version", "2023-06-01") - req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) // Apply Claude Code client headers for key, value := range claude.DefaultHeaders { @@ -254,8 +276,10 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account // Set authentication header if useBearer { + req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) req.Header.Set("Authorization", "Bearer "+authToken) } else { + req.Header.Set("anthropic-beta", claude.APIKeyBetaHeader) req.Header.Set("x-api-key", authToken) } @@ -461,6 +485,697 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account return s.processGeminiStream(c, resp.Body) } +type soraProbeStep struct { + Name string `json:"name"` + Status string `json:"status"` + HTTPStatus int `json:"http_status,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + Message string `json:"message,omitempty"` +} + +type soraProbeSummary struct { + Status string `json:"status"` + Steps []soraProbeStep `json:"steps"` +} + +type soraProbeRecorder struct { + steps []soraProbeStep +} + +func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) { + r.steps = append(r.steps, soraProbeStep{ + Name: name, + Status: status, + HTTPStatus: httpStatus, + ErrorCode: strings.TrimSpace(errorCode), + Message: strings.TrimSpace(message), + }) +} + +func (r *soraProbeRecorder) finalize() soraProbeSummary { + meSuccess := false + partial := false + for _, step := range r.steps { + if step.Name == "me" { + meSuccess = strings.EqualFold(step.Status, "success") + continue + } + if strings.EqualFold(step.Status, "failed") { + partial = true + } + } + + status := "success" + if !meSuccess { + status = "failed" + } else if partial { + status = "partial_success" + } + + return soraProbeSummary{ + Status: status, + Steps: append([]soraProbeStep(nil), r.steps...), + } +} + +func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) { + if rec == nil { + return + } + summary := rec.finalize() + code := "" + for _, step := range summary.Steps { + if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" { + code = step.ErrorCode + break + } + } + s.sendEvent(c, TestEvent{ + Type: "sora_test_result", + Status: summary.Status, + Code: code, + Data: summary, + }) +} + +func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) { + if accountID <= 0 { + return 0, true + } + s.soraTestGuardMu.Lock() + defer s.soraTestGuardMu.Unlock() + + if s.soraTestLastRun == nil { + s.soraTestLastRun = make(map[int64]time.Time) + } + cooldown := s.soraTestCooldown + if cooldown <= 0 { + cooldown = defaultSoraTestCooldown + } + + now := time.Now() + if lastRun, ok := s.soraTestLastRun[accountID]; ok { + elapsed := now.Sub(lastRun) + if elapsed < cooldown { + return cooldown - elapsed, false + } + } + s.soraTestLastRun[accountID] = now + return 0, true +} + +func ceilSeconds(d time.Duration) int { + if d <= 0 { + return 1 + } + sec := int(d / time.Second) + if d%time.Second != 0 { + sec++ + } + if sec < 1 { + sec = 1 + } + return sec +} + +// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。 +// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。 +func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error { + ctx := c.Request.Context() + + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证") + } + + baseURL := account.GetBaseURL() + if baseURL == "" { + return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url") + } + + // 验证 base_url 格式 + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error())) + } + upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions" + + // 设置 SSE 头 + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { + msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) + return s.sendErrorAndEnd(c, msg) + } + + s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"}) + + // 构建轻量级 prompt-enhance 请求作为连通性测试 + testPayload := map[string]any{ + "model": "prompt-enhance-short-10s", + "messages": []map[string]string{{"role": "user", "content": "test"}}, + "stream": false, + } + payloadBytes, _ := json.Marshal(testPayload) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes)) + if err != nil { + return s.sendErrorAndEnd(c, "构建测试请求失败") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + // 获取代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + + if resp.StatusCode == http.StatusOK { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode)) + } + + // 其他错误但能连通(如 400 参数错误)也算连通性测试通过 + if resp.StatusCode == http.StatusBadRequest { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256))) +} + +// testSoraAccountConnection 测试 Sora 账号的连接 +// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性 +// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性 +func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error { + // apikey 类型走独立测试流程 + if account.Type == AccountTypeAPIKey { + return s.testSoraAPIKeyAccountConnection(c, account) + } + + ctx := c.Request.Context() + recorder := &soraProbeRecorder{} + + authToken := account.GetCredential("access_token") + if authToken == "" { + recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "No access token available") + } + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { + msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) + recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg) + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, msg) + } + + // Send test_start event + s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"}) + + req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil) + if err != nil { + recorder.addStep("me", "failed", 0, "request_build_failed", err.Error()) + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "Failed to create request") + } + + // 使用 Sora 客户端标准请求头 + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + + // Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint() + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) + if err != nil { + recorder.addStep("me", "failed", 0, "network_error", err.Error()) + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { + recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected") + s.emitSoraProbeSummary(c, recorder) + s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body) + return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body)) + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body) + switch { + case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"): + recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号") + case strings.EqualFold(upstreamCode, "unsupported_country_code"): + recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试") + case strings.TrimSpace(upstreamMessage) != "": + recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage) + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage)) + default: + recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512))) + } + } + recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok") + + // 解析 /me 响应,提取用户信息 + var meResp map[string]any + if err := json.Unmarshal(body, &meResp); err != nil { + // 能收到 200 就说明 token 有效 + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"}) + } else { + // 尝试提取用户名或邮箱信息 + info := "Sora connection OK" + if name, ok := meResp["name"].(string); ok && name != "" { + info = fmt.Sprintf("Sora connection OK - User: %s", name) + } else if email, ok := meResp["email"].(string); ok && email != "" { + info = fmt.Sprintf("Sora connection OK - Email: %s", email) + } + s.sendEvent(c, TestEvent{Type: "content", Text: info}) + } + + // 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试) + subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil) + if err == nil { + subReq.Header.Set("Authorization", "Bearer "+authToken) + subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + subReq.Header.Set("Accept", "application/json") + subReq.Header.Set("Accept-Language", "en-US,en;q=0.9") + subReq.Header.Set("Origin", "https://sora.chatgpt.com") + subReq.Header.Set("Referer", "https://sora.chatgpt.com/") + + subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) + if subErr != nil { + recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error()) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())}) + } else { + subBody, _ := io.ReadAll(subResp.Body) + _ = subResp.Body.Close() + if subResp.StatusCode == http.StatusOK { + recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok") + if summary := parseSoraSubscriptionSummary(subBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"}) + } + } else { + if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) { + recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected") + s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)}) + } else { + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody) + recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)}) + } + } + } + } + + // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。 + s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder) + + s.emitSoraProbeSummary(c, recorder) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + +func (s *AccountTestService) testSora2Capabilities( + c *gin.Context, + ctx context.Context, + account *Account, + authToken string, + proxyURL string, + enableTLSFingerprint bool, + recorder *soraProbeRecorder, +) { + inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraInviteMineURL, + proxyURL, + enableTLSFingerprint, + ) + if err != nil { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())}) + return + } + + if inviteStatus == http.StatusUnauthorized { + bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraBootstrapURL, + proxyURL, + enableTLSFingerprint, + ) + if bootstrapErr == nil && bootstrapStatus == http.StatusOK { + if recorder != nil { + recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok") + } + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"}) + inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraInviteMineURL, + proxyURL, + enableTLSFingerprint, + ) + if err != nil { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())}) + return + } + } else if recorder != nil { + code := "" + msg := "" + if bootstrapErr != nil { + code = "network_error" + msg = bootstrapErr.Error() + } + recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg) + } + } + + if inviteStatus != http.StatusOK { + if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected") + } + s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)}) + return + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody) + if recorder != nil { + recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)}) + return + } + if recorder != nil { + recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok") + } + + if summary := parseSoraInviteSummary(inviteBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"}) + } + + remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraRemainingURL, + proxyURL, + enableTLSFingerprint, + ) + if remainingErr != nil { + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error()) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())}) + return + } + if remainingStatus != http.StatusOK { + if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) { + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected") + } + s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)}) + return + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody) + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)}) + return + } + if recorder != nil { + recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok") + } + if summary := parseSoraRemainingSummary(remainingBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"}) + } +} + +func (s *AccountTestService) fetchSoraTestEndpoint( + ctx context.Context, + account *Account, + authToken string, + url string, + proxyURL string, + enableTLSFingerprint bool, +) (int, http.Header, []byte, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return 0, nil, nil, err + } + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableTLSFingerprint) + if err != nil { + return 0, nil, nil, err + } + defer func() { _ = resp.Body.Close() }() + + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return resp.StatusCode, resp.Header, nil, readErr + } + return resp.StatusCode, resp.Header, body, nil +} + +func parseSoraSubscriptionSummary(body []byte) string { + var subResp struct { + Data []struct { + Plan struct { + ID string `json:"id"` + Title string `json:"title"` + } `json:"plan"` + EndTS string `json:"end_ts"` + } `json:"data"` + } + if err := json.Unmarshal(body, &subResp); err != nil { + return "" + } + if len(subResp.Data) == 0 { + return "" + } + + first := subResp.Data[0] + parts := make([]string, 0, 3) + if first.Plan.Title != "" { + parts = append(parts, first.Plan.Title) + } + if first.Plan.ID != "" { + parts = append(parts, first.Plan.ID) + } + if first.EndTS != "" { + parts = append(parts, "end="+first.EndTS) + } + if len(parts) == 0 { + return "" + } + return "Subscription: " + strings.Join(parts, " | ") +} + +func parseSoraInviteSummary(body []byte) string { + var inviteResp struct { + InviteCode string `json:"invite_code"` + RedeemedCount int64 `json:"redeemed_count"` + TotalCount int64 `json:"total_count"` + } + if err := json.Unmarshal(body, &inviteResp); err != nil { + return "" + } + + parts := []string{"Sora2: supported"} + if inviteResp.InviteCode != "" { + parts = append(parts, "invite="+inviteResp.InviteCode) + } + if inviteResp.TotalCount > 0 { + parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount)) + } + return strings.Join(parts, " | ") +} + +func parseSoraRemainingSummary(body []byte) string { + var remainingResp struct { + RateLimitAndCreditBalance struct { + EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"` + RateLimitReached bool `json:"rate_limit_reached"` + AccessResetsInSeconds int64 `json:"access_resets_in_seconds"` + } `json:"rate_limit_and_credit_balance"` + } + if err := json.Unmarshal(body, &remainingResp); err != nil { + return "" + } + info := remainingResp.RateLimitAndCreditBalance + parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)} + if info.RateLimitReached { + parts = append(parts, "rate_limited=true") + } + if info.AccessResetsInSeconds > 0 { + parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds)) + } + return strings.Join(parts, " | ") +} + +func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool { + if s == nil || s.cfg == nil { + return true + } + return !s.cfg.Sora.Client.DisableTLSFingerprint +} + +func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) +} + +func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string { + return soraerror.FormatCloudflareChallengeMessage(base, headers, body) +} + +func extractCloudflareRayID(headers http.Header, body []byte) string { + return soraerror.ExtractCloudflareRayID(headers, body) +} + +func extractSoraEgressIPHint(headers http.Header) string { + if headers == nil { + return "unknown" + } + candidates := []string{ + "x-openai-public-ip", + "x-envoy-external-address", + "cf-connecting-ip", + "x-forwarded-for", + } + for _, key := range candidates { + if value := strings.TrimSpace(headers.Get(key)); value != "" { + return value + } + } + return "unknown" +} + +func sanitizeProxyURLForLog(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + u, err := url.Parse(raw) + if err != nil { + return "" + } + if u.User != nil { + u.User = nil + } + return u.String() +} + +func endpointPathForLog(endpoint string) string { + parsed, err := url.Parse(strings.TrimSpace(endpoint)) + if err != nil || parsed.Path == "" { + return endpoint + } + return parsed.Path +} + +func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) { + accountID := int64(0) + platform := "" + proxyID := "none" + if account != nil { + accountID = account.ID + platform = account.Platform + if account.ProxyID != nil { + proxyID = fmt.Sprintf("%d", *account.ProxyID) + } + } + cfRay := extractCloudflareRayID(headers, body) + if cfRay == "" { + cfRay = "unknown" + } + log.Printf( + "[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s", + accountID, + platform, + endpoint, + endpointPathForLog(endpoint), + proxyID, + sanitizeProxyURLForLog(proxyURL), + cfRay, + extractSoraEgressIPHint(headers), + ) +} + +func truncateSoraErrorBody(body []byte, max int) string { + return soraerror.TruncateBody(body, max) +} + // testAntigravityAccountConnection tests an Antigravity account's connection // 支持 Claude 和 Gemini 两种协议,使用非流式请求 func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go new file mode 100644 index 00000000..3dfac786 --- /dev/null +++ b/backend/internal/service/account_test_service_sora_test.go @@ -0,0 +1,319 @@ +package service + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type queuedHTTPUpstream struct { + responses []*http.Response + requests []*http.Request + tlsFlags []bool +} + +func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, fmt.Errorf("unexpected Do call") +} + +func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) { + u.requests = append(u.requests, req) + u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint) + if len(u.responses) == 0 { + return nil, fmt.Errorf("no mocked response") + } + resp := u.responses[0] + u.responses = u.responses[1:] + return resp, nil +} + +func newJSONResponse(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func newJSONResponseWithHeader(status int, body, key, value string) *http.Response { + resp := newJSONResponse(status, body) + resp.Header.Set(key, value) + return resp +} + +func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) + return c, rec +} + +func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), + newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`), + newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`), + newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`), + }, + } + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + TLSFingerprint: config.TLSFingerprintConfig{ + Enabled: true, + }, + }, + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + DisableTLSFingerprint: false, + }, + }, + }, + } + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + require.Len(t, upstream.requests, 4) + require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String()) + require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String()) + require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String()) + require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String()) + require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization")) + require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization")) + require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags) + + body := rec.Body.String() + require.Contains(t, body, `"type":"test_start"`) + require.Contains(t, body, "Sora connection OK - Email: demo@example.com") + require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z") + require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50") + require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s") + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"success"`) + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), + newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), + newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`), + newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + require.Len(t, upstream.requests, 4) + body := rec.Body.String() + require.Contains(t, body, "Sora connection OK - User: demo-user") + require.Contains(t, body, "Subscription check returned 403") + require.Contains(t, body, "Sora2 invite check returned 401") + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"partial_success"`) + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponseWithHeader(http.StatusForbidden, `Just a moment...`, "cf-ray", "9cff2d62d83bb98d"), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "Cloudflare challenge") + require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d") + body := rec.Body.String() + require.Contains(t, body, `"type":"error"`) + require.Contains(t, body, "Cloudflare challenge") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") +} + +func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponseWithHeader(http.StatusTooManyRequests, `Just a moment...`, "cf-mitigated", "challenge"), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "Cloudflare challenge") + require.Contains(t, err.Error(), "HTTP 429") + body := rec.Body.String() + require.Contains(t, body, "Cloudflare challenge") +} + +func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "token_invalidated") + body := rec.Body.String() + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"failed"`) + require.Contains(t, body, "token_invalidated") + require.NotContains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), + }, + } + svc := &AccountTestService{ + httpUpstream: upstream, + soraTestCooldown: time.Hour, + } + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c1, _ := newSoraTestContext() + err := svc.testSoraAccountConnection(c1, account) + require.NoError(t, err) + + c2, rec2 := newSoraTestContext() + err = svc.testSoraAccountConnection(c2, account) + require.Error(t, err) + require.Contains(t, err.Error(), "测试过于频繁") + body := rec2.Body.String() + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"code":"test_rate_limited"`) + require.Contains(t, body, `"status":"failed"`) + require.NotContains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + body := rec.Body.String() + require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestSanitizeProxyURLForLog(t *testing.T) { + require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080")) + require.Equal(t, "", sanitizeProxyURLForLog("")) + require.Equal(t, "", sanitizeProxyURLForLog("://invalid")) +} + +func TestExtractSoraEgressIPHint(t *testing.T) { + h := make(http.Header) + h.Set("x-openai-public-ip", "203.0.113.10") + require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h)) + + h2 := make(http.Header) + h2.Set("x-envoy-external-address", "198.51.100.9") + require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2)) + + require.Equal(t, "unknown", extractSoraEgressIPHint(nil)) + require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{})) +} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 304c5781..6dee6c13 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -4,11 +4,14 @@ import ( "context" "fmt" "log" + "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "golang.org/x/sync/errgroup" ) type UsageLogRepository interface { @@ -32,12 +35,13 @@ type UsageLogRepository interface { // Admin dashboard stats GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) - GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) - GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) + GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) + GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) + GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) - GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) - GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) + GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) + GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) // User dashboard stats GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) @@ -61,6 +65,10 @@ type UsageLogRepository interface { GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) } +type accountWindowStatsBatchReader interface { + GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) +} + // apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) type apiUsageCache struct { response *ClaudeUsageResponse @@ -217,12 +225,20 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U } if account.Platform == PlatformGemini { - return s.getGeminiUsage(ctx, account) + usage, err := s.getGeminiUsage(ctx, account) + if err == nil { + s.tryClearRecoverableAccountError(ctx, account) + } + return usage, err } // Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度 if account.Platform == PlatformAntigravity { - return s.getAntigravityUsage(ctx, account) + usage, err := s.getAntigravityUsage(ctx, account) + if err == nil { + s.tryClearRecoverableAccountError(ctx, account) + } + return usage, err } // 只有oauth类型账号可以通过API获取usage(有profile scope) @@ -256,6 +272,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U // 4. 添加窗口统计(有独立缓存,1 分钟) s.addWindowStats(ctx, account, usage) + s.tryClearRecoverableAccountError(ctx, account) return usage, nil } @@ -287,7 +304,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou } dayStart := geminiDailyWindowStart(now) - stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return nil, fmt.Errorf("get gemini usage stats failed: %w", err) } @@ -309,7 +326,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou // Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m) minuteStart := now.Truncate(time.Minute) minuteResetAt := minuteStart.Add(time.Minute) - minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil) + minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err) } @@ -430,6 +447,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64 }, nil } +// GetTodayStatsBatch 批量获取账号今日统计,优先走批量 SQL,失败时回退单账号查询。 +func (s *AccountUsageService) GetTodayStatsBatch(ctx context.Context, accountIDs []int64) (map[int64]*WindowStats, error) { + uniqueIDs := make([]int64, 0, len(accountIDs)) + seen := make(map[int64]struct{}, len(accountIDs)) + for _, accountID := range accountIDs { + if accountID <= 0 { + continue + } + if _, exists := seen[accountID]; exists { + continue + } + seen[accountID] = struct{}{} + uniqueIDs = append(uniqueIDs, accountID) + } + + result := make(map[int64]*WindowStats, len(uniqueIDs)) + if len(uniqueIDs) == 0 { + return result, nil + } + + startTime := timezone.Today() + if batchReader, ok := s.usageLogRepo.(accountWindowStatsBatchReader); ok { + statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, uniqueIDs, startTime) + if err == nil { + for _, accountID := range uniqueIDs { + result[accountID] = windowStatsFromAccountStats(statsByAccount[accountID]) + } + return result, nil + } + } + + var mu sync.Mutex + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(8) + + for _, accountID := range uniqueIDs { + id := accountID + g.Go(func() error { + stats, err := s.usageLogRepo.GetAccountWindowStats(gctx, id, startTime) + if err != nil { + return nil + } + mu.Lock() + result[id] = windowStatsFromAccountStats(stats) + mu.Unlock() + return nil + }) + } + + _ = g.Wait() + + for _, accountID := range uniqueIDs { + if _, ok := result[accountID]; !ok { + result[accountID] = &WindowStats{} + } + } + return result, nil +} + +func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats { + if stats == nil { + return &WindowStats{} + } + return &WindowStats{ + Requests: stats.Requests, + Tokens: stats.Tokens, + Cost: stats.Cost, + StandardCost: stats.StandardCost, + UserCost: stats.UserCost, + } +} + func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime) if err != nil { @@ -486,6 +575,32 @@ func parseTime(s string) (time.Time, error) { return time.Time{}, fmt.Errorf("unable to parse time: %s", s) } +func (s *AccountUsageService) tryClearRecoverableAccountError(ctx context.Context, account *Account) { + if account == nil || account.Status != StatusError { + return + } + + msg := strings.ToLower(strings.TrimSpace(account.ErrorMessage)) + if msg == "" { + return + } + + if !strings.Contains(msg, "token refresh failed") && + !strings.Contains(msg, "invalid_client") && + !strings.Contains(msg, "missing_project_id") && + !strings.Contains(msg, "unauthenticated") { + return + } + + if err := s.accountRepo.ClearError(ctx, account.ID); err != nil { + log.Printf("[usage] failed to clear recoverable account error for account %d: %v", account.ID, err) + return + } + + account.Status = StatusActive + account.ErrorMessage = "" +} + // buildUsageInfo 构建UsageInfo func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo { info := &UsageInfo{ diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 90e5b573..7782f948 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -267,3 +267,119 @@ func TestAccountGetMappedModel(t *testing.T) { }) } } + +func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3-pro-high": "gemini-3.1-pro-high", + }, + }, + } + + mapping := account.GetModelMapping() + if mapping["gemini-3-flash"] != "gemini-3-flash" { + t.Fatalf("expected gemini-3-flash passthrough to be auto-filled, got: %q", mapping["gemini-3-flash"]) + } + if mapping["gemini-3.1-pro-high"] != "gemini-3.1-pro-high" { + t.Fatalf("expected gemini-3.1-pro-high passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-high"]) + } + if mapping["gemini-3.1-pro-low"] != "gemini-3.1-pro-low" { + t.Fatalf("expected gemini-3.1-pro-low passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-low"]) + } +} + +func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3*": "gemini-3.1-pro-high", + }, + }, + } + + mapping := account.GetModelMapping() + if _, exists := mapping["gemini-3-flash"]; exists { + t.Fatalf("did not expect explicit gemini-3-flash passthrough when wildcard already exists") + } + if _, exists := mapping["gemini-3.1-pro-high"]; exists { + t.Fatalf("did not expect explicit gemini-3.1-pro-high passthrough when wildcard already exists") + } + if _, exists := mapping["gemini-3.1-pro-low"]; exists { + t.Fatalf("did not expect explicit gemini-3.1-pro-low passthrough when wildcard already exists") + } + if mapped := account.GetMappedModel("gemini-3-flash"); mapped != "gemini-3.1-pro-high" { + t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped) + } +} + +func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) { + account := &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "upstream-a", + }, + }, + } + + first := account.GetModelMapping() + if first["claude-3-5-sonnet"] != "upstream-a" { + t.Fatalf("unexpected first mapping: %v", first) + } + + account.Credentials = map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "upstream-b", + }, + } + second := account.GetModelMapping() + if second["claude-3-5-sonnet"] != "upstream-b" { + t.Fatalf("expected cache invalidated after credentials replace, got: %v", second) + } +} + +func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) { + rawMapping := map[string]any{ + "claude-sonnet": "sonnet-a", + } + account := &Account{ + Credentials: map[string]any{ + "model_mapping": rawMapping, + }, + } + + first := account.GetModelMapping() + if len(first) != 1 { + t.Fatalf("unexpected first mapping length: %d", len(first)) + } + + rawMapping["claude-opus"] = "opus-b" + second := account.GetModelMapping() + if second["claude-opus"] != "opus-b" { + t.Fatalf("expected cache invalidated after mapping len change, got: %v", second) + } +} + +func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) { + rawMapping := map[string]any{ + "claude-sonnet": "sonnet-a", + } + account := &Account{ + Credentials: map[string]any{ + "model_mapping": rawMapping, + }, + } + + first := account.GetModelMapping() + if first["claude-sonnet"] != "sonnet-a" { + t.Fatalf("unexpected first mapping: %v", first) + } + + rawMapping["claude-sonnet"] = "sonnet-b" + second := account.GetModelMapping() + if second["claude-sonnet"] != "sonnet-b" { + t.Fatalf("expected cache invalidated after in-place value change, got: %v", second) + } +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 59d7062b..7e6982d3 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -4,11 +4,17 @@ import ( "context" "errors" "fmt" - "log" + "io" + "net/http" "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" ) // AdminService interface defines admin management operations @@ -36,9 +42,13 @@ type AdminService interface { UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) DeleteGroup(ctx context.Context, id int64) error GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) + UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error + + // API Key management (admin) + AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) // Account management - ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) GetAccount(ctx context.Context, id int64) (*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) @@ -49,6 +59,7 @@ type AdminService interface { SetAccountError(ctx context.Context, id int64, errorMsg string) error SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) + CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error // Proxy management ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) @@ -64,6 +75,7 @@ type AdminService interface { GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) + CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) // Redeem code management ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) @@ -76,13 +88,14 @@ type AdminService interface { // CreateUserInput represents input for creating a new user via admin operations. type CreateUserInput struct { - Email string - Password string - Username string - Notes string - Balance float64 - Concurrency int - AllowedGroups []int64 + Email string + Password string + Username string + Notes string + Balance float64 + Concurrency int + AllowedGroups []int64 + SoraStorageQuotaBytes int64 } type UpdateUserInput struct { @@ -96,7 +109,8 @@ type UpdateUserInput struct { AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 + GroupRates map[int64]*float64 + SoraStorageQuotaBytes *int64 } type CreateGroupInput struct { @@ -110,11 +124,16 @@ type CreateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -123,6 +142,8 @@ type CreateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string + // Sora 存储配额 + SoraStorageQuotaBytes int64 // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -139,11 +160,16 @@ type UpdateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -152,6 +178,8 @@ type UpdateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string + // Sora 存储配额 + SoraStorageQuotaBytes *int64 // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -219,6 +247,14 @@ type BulkUpdateAccountResult struct { Error string `json:"error,omitempty"` } +// AdminUpdateAPIKeyGroupIDResult is the result of AdminUpdateAPIKeyGroupID. +type AdminUpdateAPIKeyGroupIDResult struct { + APIKey *APIKey + AutoGrantedGroupAccess bool // true if a new exclusive group permission was auto-added + GrantedGroupID *int64 // the group ID that was auto-granted + GrantedGroupName string // the group name that was auto-granted +} + // BulkUpdateAccountsResult is the aggregated response for bulk updates. type BulkUpdateAccountsResult struct { Success int `json:"success"` @@ -277,6 +313,32 @@ type ProxyTestResult struct { CountryCode string `json:"country_code,omitempty"` } +type ProxyQualityCheckResult struct { + ProxyID int64 `json:"proxy_id"` + Score int `json:"score"` + Grade string `json:"grade"` + Summary string `json:"summary"` + ExitIP string `json:"exit_ip,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + BaseLatencyMs int64 `json:"base_latency_ms,omitempty"` + PassedCount int `json:"passed_count"` + WarnCount int `json:"warn_count"` + FailedCount int `json:"failed_count"` + ChallengeCount int `json:"challenge_count"` + CheckedAt int64 `json:"checked_at"` + Items []ProxyQualityCheckItem `json:"items"` +} + +type ProxyQualityCheckItem struct { + Target string `json:"target"` + Status string `json:"status"` // pass/warn/fail/challenge + HTTPStatus int `json:"http_status,omitempty"` + LatencyMs int64 `json:"latency_ms,omitempty"` + Message string `json:"message,omitempty"` + CFRay string `json:"cf_ray,omitempty"` +} + // ProxyExitInfo represents proxy exit information from ip-api.com type ProxyExitInfo struct { IP string @@ -291,11 +353,64 @@ type ProxyExitInfoProber interface { ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) } +type proxyQualityTarget struct { + Target string + URL string + Method string + AllowedStatuses map[int]struct{} +} + +var proxyQualityTargets = []proxyQualityTarget{ + { + Target: "openai", + URL: "https://api.openai.com/v1/models", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + }, + { + Target: "anthropic", + URL: "https://api.anthropic.com/v1/messages", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + http.StatusMethodNotAllowed: {}, + http.StatusNotFound: {}, + http.StatusBadRequest: {}, + }, + }, + { + Target: "gemini", + URL: "https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusOK: {}, + }, + }, + { + Target: "sora", + URL: "https://sora.chatgpt.com/backend/me", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + }, +} + +const ( + proxyQualityRequestTimeout = 15 * time.Second + proxyQualityResponseHeaderTimeout = 10 * time.Second + proxyQualityMaxBodyBytes = int64(8 * 1024) + proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" +) + // adminServiceImpl implements AdminService type adminServiceImpl struct { userRepo UserRepository groupRepo GroupRepository accountRepo AccountRepository + soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository @@ -304,6 +419,17 @@ type adminServiceImpl struct { proxyProber ProxyExitInfoProber proxyLatencyCache ProxyLatencyCache authCacheInvalidator APIKeyAuthCacheInvalidator + entClient *dbent.Client // 用于开启数据库事务 + settingService *SettingService + defaultSubAssigner DefaultSubscriptionAssigner +} + +type userGroupRateBatchReader interface { + GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) +} + +type groupExistenceBatchReader interface { + ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) } // NewAdminService creates a new AdminService @@ -311,6 +437,7 @@ func NewAdminService( userRepo UserRepository, groupRepo GroupRepository, accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, @@ -319,11 +446,15 @@ func NewAdminService( proxyProber ProxyExitInfoProber, proxyLatencyCache ProxyLatencyCache, authCacheInvalidator APIKeyAuthCacheInvalidator, + entClient *dbent.Client, + settingService *SettingService, + defaultSubAssigner DefaultSubscriptionAssigner, ) AdminService { return &adminServiceImpl{ userRepo: userRepo, groupRepo: groupRepo, accountRepo: accountRepo, + soraAccountRepo: soraAccountRepo, proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, @@ -332,6 +463,9 @@ func NewAdminService( proxyProber: proxyProber, proxyLatencyCache: proxyLatencyCache, authCacheInvalidator: authCacheInvalidator, + entClient: entClient, + settingService: settingService, + defaultSubAssigner: defaultSubAssigner, } } @@ -344,18 +478,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi } // 批量加载用户专属分组倍率 if s.userGroupRateRepo != nil && len(users) > 0 { - for i := range users { - rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) - if err != nil { - log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err) - continue + if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok { + userIDs := make([]int64, 0, len(users)) + for i := range users { + userIDs = append(userIDs, users[i].ID) } - users[i].GroupRates = rates + ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err) + s.loadUserGroupRatesOneByOne(ctx, users) + } else { + for i := range users { + if rates, ok := ratesByUser[users[i].ID]; ok { + users[i].GroupRates = rates + } + } + } + } else { + s.loadUserGroupRatesOneByOne(ctx, users) } } return users, result.Total, nil } +func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) { + if s.userGroupRateRepo == nil { + return + } + for i := range users { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err) + continue + } + users[i].GroupRates = rates + } +} + func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { user, err := s.userRepo.GetByID(ctx, id) if err != nil { @@ -365,7 +524,7 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) if s.userGroupRateRepo != nil { rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) if err != nil { - log.Printf("failed to load user group rates: user_id=%d err=%v", id, err) + logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", id, err) } else { user.GroupRates = rates } @@ -375,14 +534,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { user := &User{ - Email: input.Email, - Username: input.Username, - Notes: input.Notes, - Role: RoleUser, // Always create as regular user, never admin - Balance: input.Balance, - Concurrency: input.Concurrency, - Status: StatusActive, - AllowedGroups: input.AllowedGroups, + Email: input.Email, + Username: input.Username, + Notes: input.Notes, + Role: RoleUser, // Always create as regular user, never admin + Balance: input.Balance, + Concurrency: input.Concurrency, + Status: StatusActive, + AllowedGroups: input.AllowedGroups, + SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, } if err := user.SetPassword(input.Password); err != nil { return nil, err @@ -390,9 +550,27 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu if err := s.userRepo.Create(ctx, user); err != nil { return nil, err } + s.assignDefaultSubscriptions(ctx, user.ID) return user, nil } +func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) { + if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { + return + } + items := s.settingService.GetDefaultSubscriptions(ctx) + for _, item := range items { + if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: "auto assigned by default user subscriptions setting", + }); err != nil { + logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) + } + } +} + func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { user, err := s.userRepo.GetByID(ctx, id) if err != nil { @@ -436,6 +614,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda user.AllowedGroups = *input.AllowedGroups } + if input.SoraStorageQuotaBytes != nil { + user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes + } + if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } @@ -443,7 +625,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda // 同步用户专属分组倍率 if input.GroupRates != nil && s.userGroupRateRepo != nil { if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil { - log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err) + logger.LegacyPrintf("service.admin", "failed to sync user group rates: user_id=%d err=%v", user.ID, err) } } @@ -457,7 +639,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda if concurrencyDiff != 0 { code, err := GenerateRedeemCode() if err != nil { - log.Printf("failed to generate adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) return user, nil } adjustmentRecord := &RedeemCode{ @@ -470,7 +652,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda now := time.Now() adjustmentRecord.UsedAt = &now if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { - log.Printf("failed to create concurrency adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to create concurrency adjustment redeem code: %v", err) } } @@ -487,7 +669,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { return errors.New("cannot delete admin user") } if err := s.userRepo.Delete(ctx, id); err != nil { - log.Printf("delete user failed: user_id=%d err=%v", id, err) + logger.LegacyPrintf("service.admin", "delete user failed: user_id=%d err=%v", id, err) return err } if s.authCacheInvalidator != nil { @@ -530,7 +712,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil { - log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) + logger.LegacyPrintf("service.admin", "invalidate user balance cache failed: user_id=%d err=%v", userID, err) } }() } @@ -538,7 +720,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, if balanceDiff != 0 { code, err := GenerateRedeemCode() if err != nil { - log.Printf("failed to generate adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) return user, nil } @@ -554,7 +736,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, adjustmentRecord.UsedAt = &now if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { - log.Printf("failed to create balance adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to create balance adjustment redeem code: %v", err) } } @@ -638,6 +820,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) + soraImagePrice360 := normalizePrice(input.SoraImagePrice360) + soraImagePrice540 := normalizePrice(input.SoraImagePrice540) + soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest) + soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD) // 校验降级分组 if input.FallbackGroupID != nil { @@ -708,12 +894,17 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ImagePrice1K: imagePrice1K, ImagePrice2K: imagePrice2K, ImagePrice4K: imagePrice4K, + SoraImagePrice360: soraImagePrice360, + SoraImagePrice540: soraImagePrice540, + SoraVideoPricePerRequest: soraVideoPrice, + SoraVideoPricePerRequestHD: soraVideoPriceHD, ClaudeCodeOnly: input.ClaudeCodeOnly, FallbackGroupID: input.FallbackGroupID, FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, ModelRouting: input.ModelRouting, MCPXMLInject: mcpXMLInject, SupportedModelScopes: input.SupportedModelScopes, + SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -864,6 +1055,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.ImagePrice4K != nil { group.ImagePrice4K = normalizePrice(input.ImagePrice4K) } + if input.SoraImagePrice360 != nil { + group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360) + } + if input.SoraImagePrice540 != nil { + group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540) + } + if input.SoraVideoPricePerRequest != nil { + group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest) + } + if input.SoraVideoPricePerRequestHD != nil { + group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) + } + if input.SoraStorageQuotaBytes != nil { + group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes + } // Claude Code 客户端限制 if input.ClaudeCodeOnly != nil { @@ -992,7 +1198,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { defer cancel() for _, userID := range affectedUserIDs { if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil { - log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) + logger.LegacyPrintf("service.admin", "invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) } } }() @@ -1015,10 +1221,111 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p return keys, result.Total, nil } +func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { + return s.groupRepo.UpdateSortOrders(ctx, updates) +} + +// AdminUpdateAPIKeyGroupID 管理员修改 API Key 分组绑定 +// groupID: nil=不修改, 指向0=解绑, 指向正整数=绑定到目标分组 +func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) { + apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID) + if err != nil { + return nil, err + } + + if groupID == nil { + // nil 表示不修改,直接返回 + return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil + } + + if *groupID < 0 { + return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative") + } + + result := &AdminUpdateAPIKeyGroupIDResult{} + + if *groupID == 0 { + // 0 表示解绑分组(不修改 user_allowed_groups,避免影响用户其他 Key) + apiKey.GroupID = nil + apiKey.Group = nil + } else { + // 验证目标分组存在且状态为 active + group, err := s.groupRepo.GetByID(ctx, *groupID) + if err != nil { + return nil, err + } + if group.Status != StatusActive { + return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") + } + // 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程 + if group.IsSubscriptionType() { + return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow") + } + + gid := *groupID + apiKey.GroupID = &gid + apiKey.Group = group + + // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性 + if group.IsExclusive { + opCtx := ctx + var tx *dbent.Tx + if s.entClient == nil { + logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding") + } else { + var txErr error + tx, txErr = s.entClient.Tx(ctx) + if txErr != nil { + return nil, fmt.Errorf("begin transaction: %w", txErr) + } + defer func() { _ = tx.Rollback() }() + opCtx = dbent.NewTxContext(ctx, tx) + } + + if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil { + return nil, fmt.Errorf("add group to user allowed groups: %w", addErr) + } + if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil { + return nil, fmt.Errorf("update api key: %w", err) + } + if tx != nil { + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + } + + result.AutoGrantedGroupAccess = true + result.GrantedGroupID = &gid + result.GrantedGroupName = group.Name + + // 失效认证缓存(在事务提交后执行) + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + result.APIKey = apiKey + return result, nil + } + } + + // 非专属分组 / 解绑:无需事务,单步更新即可 + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { + return nil, fmt.Errorf("update api key: %w", err) + } + + // 失效认证缓存 + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + result.APIKey = apiKey + return result, nil +} + // Account management implementations -func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) { +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) + accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID) if err != nil { return nil, 0, err } @@ -1066,6 +1373,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } + // Sora apikey 账号的 base_url 必填校验 + if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey { + baseURL, _ := input.Credentials["base_url"].(string) + baseURL = strings.TrimSpace(baseURL) + if baseURL == "" { + return nil, errors.New("sora apikey 账号必须设置 base_url") + } + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") + } + } + account := &Account{ Name: input.Name, Notes: normalizeAccountNotes(input.Notes), @@ -1098,6 +1417,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou return nil, err } + // 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录 + if account.Platform == PlatformSora && s.soraAccountRepo != nil { + soraUpdates := map[string]any{ + "access_token": account.GetCredential("access_token"), + "refresh_token": account.GetCredential("refresh_token"), + } + if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil { + // 只记录警告日志,不阻塞账号创建 + logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err) + } + } + // 绑定分组 if len(groupIDs) > 0 { if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { @@ -1167,12 +1498,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.AutoPauseOnExpired = *input.AutoPauseOnExpired } + // Sora apikey 账号的 base_url 必填校验 + if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey { + baseURL, _ := account.Credentials["base_url"].(string) + baseURL = strings.TrimSpace(baseURL) + if baseURL == "" { + return nil, errors.New("sora apikey 账号必须设置 base_url") + } + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") + } + } + // 先验证分组是否存在(在任何写操作之前) if input.GroupIDs != nil { - for _, groupID := range *input.GroupIDs { - if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { + return nil, err } // 检查混合渠道风险(除非用户已确认) @@ -1195,7 +1536,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U } // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象) - return s.accountRepo.GetByID(ctx, id) + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + return updated, nil } // BulkUpdateAccounts updates multiple accounts in one request. @@ -1210,10 +1555,17 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if len(input.AccountIDs) == 0 { return result, nil } + if input.GroupIDs != nil { + if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { + return nil, err + } + } - // Preload account platforms for mixed channel risk checks if group bindings are requested. + needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck + + // 预加载账号平台信息(混合渠道检查需要)。 platformByID := map[int64]string{} - if input.GroupIDs != nil && !input.SkipMixedChannelCheck { + if needMixedChannelCheck { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) if err != nil { return nil, err @@ -1225,6 +1577,19 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp } } + // 预检查混合渠道风险:在任何写操作之前,若发现风险立即返回错误。 + if needMixedChannelCheck { + for _, accountID := range input.AccountIDs { + platform := platformByID[accountID] + if platform == "" { + continue + } + if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { + return nil, err + } + } + } + if input.RateMultiplier != nil { if *input.RateMultiplier < 0 { return nil, errors.New("rate_multiplier must be >= 0") @@ -1268,31 +1633,6 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp entry := BulkUpdateAccountResult{AccountID: accountID} if input.GroupIDs != nil { - // 检查混合渠道风险(除非用户已确认) - if !input.SkipMixedChannelCheck { - platform := platformByID[accountID] - if platform == "" { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err != nil { - entry.Success = false - entry.Error = err.Error() - result.Failed++ - result.FailedIDs = append(result.FailedIDs, accountID) - result.Results = append(result.Results, entry) - continue - } - platform = account.Platform - } - if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { - entry.Success = false - entry.Error = err.Error() - result.Failed++ - result.FailedIDs = append(result.FailedIDs, accountID) - result.Results = append(result.Results, entry) - continue - } - } - if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil { entry.Success = false entry.Error = err.Error() @@ -1313,7 +1653,10 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp } func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { - return s.accountRepo.Delete(ctx, id) + if err := s.accountRepo.Delete(ctx, id); err != nil { + return err + } + return nil } func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { @@ -1346,7 +1689,11 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { return nil, err } - return s.accountRepo.GetByID(ctx, id) + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + return updated, nil } // Proxy management implementations @@ -1624,6 +1971,269 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR }, nil } +func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + result := &ProxyQualityCheckResult{ + ProxyID: id, + Score: 100, + Grade: "A", + CheckedAt: time.Now().Unix(), + Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1), + } + + proxyURL := proxy.URL() + if s.proxyProber == nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + Message: "代理探测服务未配置", + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + LatencyMs: latencyMs, + Message: err.Error(), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + result.ExitIP = exitInfo.IP + result.Country = exitInfo.Country + result.CountryCode = exitInfo.CountryCode + result.BaseLatencyMs = latencyMs + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "pass", + LatencyMs: latencyMs, + Message: "代理出口连通正常", + }) + result.PassedCount++ + + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: proxyQualityRequestTimeout, + ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, + }) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "http_client", + Status: "fail", + Message: fmt.Sprintf("创建检测客户端失败: %v", err), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil + } + + for _, target := range proxyQualityTargets { + item := runProxyQualityTarget(ctx, client, target) + result.Items = append(result.Items, item) + switch item.Status { + case "pass": + result.PassedCount++ + case "warn": + result.WarnCount++ + case "challenge": + result.ChallengeCount++ + default: + result.FailedCount++ + } + } + + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil +} + +func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem { + item := ProxyQualityCheckItem{ + Target: target.Target, + } + + req, err := http.NewRequestWithContext(ctx, target.Method, target.URL, nil) + if err != nil { + item.Status = "fail" + item.Message = fmt.Sprintf("构建请求失败: %v", err) + return item + } + req.Header.Set("Accept", "application/json,text/html,*/*") + req.Header.Set("User-Agent", proxyQualityClientUserAgent) + + start := time.Now() + resp, err := client.Do(req) + if err != nil { + item.Status = "fail" + item.LatencyMs = time.Since(start).Milliseconds() + item.Message = fmt.Sprintf("请求失败: %v", err) + return item + } + defer func() { _ = resp.Body.Close() }() + item.LatencyMs = time.Since(start).Milliseconds() + item.HTTPStatus = resp.StatusCode + + body, readErr := io.ReadAll(io.LimitReader(resp.Body, proxyQualityMaxBodyBytes+1)) + if readErr != nil { + item.Status = "fail" + item.Message = fmt.Sprintf("读取响应失败: %v", readErr) + return item + } + if int64(len(body)) > proxyQualityMaxBodyBytes { + body = body[:proxyQualityMaxBodyBytes] + } + + if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { + item.Status = "challenge" + item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body) + item.Message = "Sora 命中 Cloudflare challenge" + return item + } + + if _, ok := target.AllowedStatuses[resp.StatusCode]; ok { + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices { + item.Status = "pass" + item.Message = fmt.Sprintf("HTTP %d", resp.StatusCode) + } else { + item.Status = "warn" + item.Message = fmt.Sprintf("HTTP %d(目标可达,但鉴权或方法受限)", resp.StatusCode) + } + return item + } + + if resp.StatusCode == http.StatusTooManyRequests { + item.Status = "warn" + item.Message = "目标返回 429,可能存在频控" + return item + } + + item.Status = "fail" + item.Message = fmt.Sprintf("非预期状态码: %d", resp.StatusCode) + return item +} + +func finalizeProxyQualityResult(result *ProxyQualityCheckResult) { + if result == nil { + return + } + score := 100 - result.WarnCount*10 - result.FailedCount*22 - result.ChallengeCount*30 + if score < 0 { + score = 0 + } + result.Score = score + result.Grade = proxyQualityGrade(score) + result.Summary = fmt.Sprintf( + "通过 %d 项,告警 %d 项,失败 %d 项,挑战 %d 项", + result.PassedCount, + result.WarnCount, + result.FailedCount, + result.ChallengeCount, + ) +} + +func proxyQualityGrade(score int) string { + switch { + case score >= 90: + return "A" + case score >= 75: + return "B" + case score >= 60: + return "C" + case score >= 40: + return "D" + default: + return "F" + } +} + +func proxyQualityOverallStatus(result *ProxyQualityCheckResult) string { + if result == nil { + return "" + } + if result.ChallengeCount > 0 { + return "challenge" + } + if result.FailedCount > 0 { + return "failed" + } + if result.WarnCount > 0 { + return "warn" + } + if result.PassedCount > 0 { + return "healthy" + } + return "failed" +} + +func proxyQualityFirstCFRay(result *ProxyQualityCheckResult) string { + if result == nil { + return "" + } + for _, item := range result.Items { + if item.CFRay != "" { + return item.CFRay + } + } + return "" +} + +func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool { + if result == nil { + return false + } + for _, item := range result.Items { + if item.Target == "base_connectivity" { + return item.Status == "pass" + } + } + return false +} + +func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) { + if result == nil { + return + } + score := result.Score + checkedAt := result.CheckedAt + info := &ProxyLatencyInfo{ + Success: proxyQualityBaseConnectivityPass(result), + Message: result.Summary, + QualityStatus: proxyQualityOverallStatus(result), + QualityScore: &score, + QualityGrade: result.Grade, + QualitySummary: result.Summary, + QualityCheckedAt: &checkedAt, + QualityCFRay: proxyQualityFirstCFRay(result), + UpdatedAt: time.Now(), + } + if result.BaseLatencyMs > 0 { + latency := result.BaseLatencyMs + info.LatencyMs = &latency + } + if exitInfo != nil { + info.IPAddress = exitInfo.IP + info.Country = exitInfo.Country + info.CountryCode = exitInfo.CountryCode + info.Region = exitInfo.Region + info.City = exitInfo.City + } + s.saveProxyLatency(ctx, proxyID, info) +} + func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) { if s.proxyProber == nil || proxy == nil { return @@ -1701,6 +2311,40 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc return nil } +func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + if s.groupRepo == nil { + return errors.New("group repository not configured") + } + + if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok { + existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs) + if err != nil { + return fmt.Errorf("check groups exists: %w", err) + } + for _, groupID := range groupIDs { + if groupID <= 0 || !existsByID[groupID] { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + } + return nil + } + + for _, groupID := range groupIDs { + if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { + return fmt.Errorf("get group: %w", err) + } + } + return nil +} + +// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform. +func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { + return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs) +} + func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) { if s.proxyLatencyCache == nil || len(proxies) == 0 { return @@ -1713,7 +2357,7 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids) if err != nil { - log.Printf("Warning: load proxy latency cache failed: %v", err) + logger.LegacyPrintf("service.admin", "Warning: load proxy latency cache failed: %v", err) return } @@ -1734,6 +2378,11 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro proxies[i].CountryCode = info.CountryCode proxies[i].Region = info.Region proxies[i].City = info.City + proxies[i].QualityStatus = info.QualityStatus + proxies[i].QualityScore = info.QualityScore + proxies[i].QualityGrade = info.QualityGrade + proxies[i].QualitySummary = info.QualitySummary + proxies[i].QualityChecked = info.QualityCheckedAt } } @@ -1741,8 +2390,28 @@ func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, if s.proxyLatencyCache == nil || info == nil { return } - if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil { - log.Printf("Warning: store proxy latency cache failed: %v", err) + + merged := *info + if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil { + if existing := latencies[proxyID]; existing != nil { + if merged.QualityCheckedAt == nil && + merged.QualityScore == nil && + merged.QualityGrade == "" && + merged.QualityStatus == "" && + merged.QualitySummary == "" && + merged.QualityCFRay == "" { + merged.QualityStatus = existing.QualityStatus + merged.QualityScore = existing.QualityScore + merged.QualityGrade = existing.QualityGrade + merged.QualitySummary = existing.QualitySummary + merged.QualityCheckedAt = existing.QualityCheckedAt + merged.QualityCFRay = existing.QualityCFRay + } + } + } + + if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil { + logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err) } } diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go new file mode 100644 index 00000000..9210a786 --- /dev/null +++ b/backend/internal/service/admin_service_apikey_test.go @@ -0,0 +1,420 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Stubs +// --------------------------------------------------------------------------- + +// userRepoStubForGroupUpdate implements UserRepository for AdminUpdateAPIKeyGroupID tests. +type userRepoStubForGroupUpdate struct { + addGroupErr error + addGroupCalled bool + addedUserID int64 + addedGroupID int64 +} + +func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, userID int64, groupID int64) error { + s.addGroupCalled = true + s.addedUserID = userID + s.addedGroupID = groupID + return s.addGroupErr +} + +func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } + +// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests. +type apiKeyRepoStubForGroupUpdate struct { + key *APIKey + getErr error + updateErr error + updated *APIKey // captures what was passed to Update +} + +func (s *apiKeyRepoStubForGroupUpdate) GetByID(_ context.Context, _ int64) (*APIKey, error) { + if s.getErr != nil { + return nil, s.getErr + } + clone := *s.key + return &clone, nil +} +func (s *apiKeyRepoStubForGroupUpdate) Update(_ context.Context, key *APIKey) error { + if s.updateErr != nil { + return s.updateErr + } + clone := *key + s.updated = &clone + return nil +} + +// Unused methods – panic on unexpected call. +func (s *apiKeyRepoStubForGroupUpdate) Create(context.Context, *APIKey) error { panic("unexpected") } +func (s *apiKeyRepoStubForGroupUpdate) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) GetByKey(context.Context, string) (*APIKey, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string) (*APIKey, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) CountByUserID(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) ExistsByKey(context.Context, string) (bool, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) ListKeysByUserID(context.Context, int64) ([]string, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) ListKeysByGroupID(context.Context, int64) ([]string, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error { + panic("unexpected") +} + +// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests. +type groupRepoStubForGroupUpdate struct { + group *Group + getErr error + lastGetByIDArg int64 +} + +func (s *groupRepoStubForGroupUpdate) GetByID(_ context.Context, id int64) (*Group, error) { + s.lastGetByIDArg = id + if s.getErr != nil { + return nil, s.getErr + } + clone := *s.group + return &clone, nil +} + +// Unused methods – panic on unexpected call. +func (s *groupRepoStubForGroupUpdate) Create(context.Context, *Group) error { panic("unexpected") } +func (s *groupRepoStubForGroupUpdate) GetByIDLite(context.Context, int64) (*Group, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) Update(context.Context, *Group) error { panic("unexpected") } +func (s *groupRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *groupRepoStubForGroupUpdate) DeleteCascade(context.Context, int64) ([]int64, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) ListActive(context.Context) ([]Group, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, string) ([]Group, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) BindAccountsToGroup(context.Context, int64, []int64) error { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error { + panic("unexpected") +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +func TestAdminService_AdminUpdateAPIKeyGroupID_KeyNotFound(t *testing.T) { + repo := &apiKeyRepoStubForGroupUpdate{getErr: ErrAPIKeyNotFound} + svc := &adminServiceImpl{apiKeyRepo: repo} + + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 999, int64Ptr(1)) + require.ErrorIs(t, err, ErrAPIKeyNotFound) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_NilGroupID_NoOp(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5)} + repo := &apiKeyRepoStubForGroupUpdate{key: existing} + svc := &adminServiceImpl{apiKeyRepo: repo} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, nil) + require.NoError(t, err) + require.Equal(t, int64(1), got.APIKey.ID) + // Update should NOT have been called (updated stays nil) + require.Nil(t, repo.updated) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5), Group: &Group{ID: 5, Name: "Old"}} + repo := &apiKeyRepoStubForGroupUpdate{key: existing} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: repo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0)) + require.NoError(t, err) + require.Nil(t, got.APIKey.GroupID, "group_id should be nil after unbind") + require.Nil(t, got.APIKey.Group, "group object should be nil after unbind") + require.NotNil(t, repo.updated, "Update should have been called") + require.Nil(t, repo.updated.GroupID) + require.Equal(t, []string{"sk-test"}, cache.keys, "cache should be invalidated") +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_BindActiveGroup(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + require.Equal(t, int64(10), *got.APIKey.GroupID) + require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID) + require.Equal(t, []string{"sk-test"}, cache.keys) + // M3: verify correct group ID was passed to repo + require.Equal(t, int64(10), groupRepo.lastGetByIDArg) + // C1 fix: verify Group object is populated + require.NotNil(t, got.APIKey.Group) + require.Equal(t, "Pro", got.APIKey.Group.Name) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_SameGroup_Idempotent(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Pro"}} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + require.Equal(t, int64(10), *got.APIKey.GroupID) + // Update is still called (current impl doesn't short-circuit on same group) + require.NotNil(t, apiKeyRepo.updated) + require.Equal(t, []string{"sk-test"}, cache.keys) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotFound(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test"} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{getErr: ErrGroupNotFound} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo} + + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(99)) + require.ErrorIs(t, err, ErrGroupNotFound) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotActive(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test"} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 5, Status: StatusDisabled}} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo} + + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(5)) + require.Error(t, err) + require.Equal(t, "GROUP_NOT_ACTIVE", infraerrors.Reason(err)) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_UpdateFails(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(3)} + repo := &apiKeyRepoStubForGroupUpdate{key: existing, updateErr: errors.New("db write error")} + svc := &adminServiceImpl{apiKeyRepo: repo} + + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0)) + require.Error(t, err) + require.Contains(t, err.Error(), "update api key") +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_NegativeGroupID(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test"} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo} + + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(-5)) + require.Error(t, err) + require.Equal(t, "INVALID_GROUP_ID", infraerrors.Reason(err)) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_PointerIsolation(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache} + + inputGID := int64(10) + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, &inputGID) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + // Mutating the input pointer must NOT affect the stored value + inputGID = 999 + require.Equal(t, int64(10), *got.APIKey.GroupID) + require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_NilCacheInvalidator(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test"} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 7, Status: StatusActive}} + // authCacheInvalidator is nil – should not panic + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(7)) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + require.Equal(t, int64(7), *got.APIKey.GroupID) +} + +// --------------------------------------------------------------------------- +// Tests: AllowedGroup auto-sync +// --------------------------------------------------------------------------- + +func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AddsAllowedGroup(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}} + userRepo := &userRepoStubForGroupUpdate{} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + require.Equal(t, int64(10), *got.APIKey.GroupID) + // 验证 AddGroupToAllowedGroups 被调用,且参数正确 + require.True(t, userRepo.addGroupCalled) + require.Equal(t, int64(42), userRepo.addedUserID) + require.Equal(t, int64(10), userRepo.addedGroupID) + // 验证 result 标记了自动授权 + require.True(t, got.AutoGrantedGroupAccess) + require.NotNil(t, got.GrantedGroupID) + require.Equal(t, int64(10), *got.GrantedGroupID) + require.Equal(t, "Exclusive", got.GrantedGroupName) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupUpdate(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Public", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeStandard}} + userRepo := &userRepoStubForGroupUpdate{} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + // 非专属分组不触发 AddGroupToAllowedGroups + require.False(t, userRepo.addGroupCalled) + require.False(t, got.AutoGrantedGroupAccess) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} + userRepo := &userRepoStubForGroupUpdate{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} + + // 订阅类型分组应被阻止绑定 + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.Error(t, err) + require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err)) + require.False(t, userRepo.addGroupCalled) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AllowedGroupAddFails_ReturnsError(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}} + userRepo := &userRepoStubForGroupUpdate{addGroupErr: errors.New("db error")} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} + + // 严格模式:AddGroupToAllowedGroups 失败时,整体操作报错 + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.Error(t, err) + require.Contains(t, err.Error(), "add group to user allowed groups") + require.True(t, userRepo.addGroupCalled) + // apiKey 不应被更新 + require.Nil(t, apiKeyRepo.updated) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind_NoAllowedGroupUpdate(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Exclusive"}} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + userRepo := &userRepoStubForGroupUpdate{} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, userRepo: userRepo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0)) + require.NoError(t, err) + require.Nil(t, got.APIKey.GroupID) + // 解绑时不修改 allowed_groups + require.False(t, userRepo.addGroupCalled) + require.False(t, got.AutoGrantedGroupAccess) +} diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index 662b95fb..4845d87c 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -15,6 +15,16 @@ type accountRepoStubForBulkUpdate struct { bulkUpdateErr error bulkUpdateIDs []int64 bindGroupErrByID map[int64]error + bindGroupsCalls []int64 + getByIDsAccounts []*Account + getByIDsErr error + getByIDsCalled bool + getByIDsIDs []int64 + getByIDAccounts map[int64]*Account + getByIDErrByID map[int64]error + getByIDCalled []int64 + listByGroupData map[int64][]Account + listByGroupErr map[int64]error } func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { @@ -26,12 +36,43 @@ func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64 } func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error { + s.bindGroupsCalls = append(s.bindGroupsCalls, accountID) if err, ok := s.bindGroupErrByID[accountID]; ok { return err } return nil } +func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) { + s.getByIDsCalled = true + s.getByIDsIDs = append([]int64{}, ids...) + if s.getByIDsErr != nil { + return nil, s.getByIDsErr + } + return s.getByIDsAccounts, nil +} + +func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Account, error) { + s.getByIDCalled = append(s.getByIDCalled, id) + if err, ok := s.getByIDErrByID[id]; ok { + return nil, err + } + if account, ok := s.getByIDAccounts[id]; ok { + return account, nil + } + return nil, errors.New("account not found") +} + +func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) { + if err, ok := s.listByGroupErr[groupID]; ok { + return nil, err + } + if rows, ok := s.listByGroupData[groupID]; ok { + return rows, nil + } + return nil, nil +} + // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { repo := &accountRepoStubForBulkUpdate{} @@ -59,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { 2: errors.New("bind failed"), }, } - svc := &adminServiceImpl{accountRepo: repo} + svc := &adminServiceImpl{ + accountRepo: repo, + groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}}, + } groupIDs := []int64{10} schedulable := false @@ -78,3 +122,51 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { require.ElementsMatch(t, []int64{2}, result.FailedIDs) require.Len(t, result.Results, 3) } + +func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{} + svc := &adminServiceImpl{accountRepo: repo} + + groupIDs := []int64{10} + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1}, + GroupIDs: &groupIDs, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "group repository not configured") +} + +// TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies +// that the global pre-check detects a conflict with existing group members and returns an +// error before any DB write is performed. +func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformAntigravity}, + }, + // Group 10 already contains an Anthropic account. + listByGroupData: map[int64][]Account{ + 10: {{ID: 99, Platform: PlatformAnthropic}}, + }, + } + svc := &adminServiceImpl{ + accountRepo: repo, + groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "target-group"}}, + } + + groupIDs := []int64{10} + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1}, + GroupIDs: &groupIDs, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "mixed channel") + // No BindGroups should have been called since the check runs before any write. + require.Empty(t, repo.bindGroupsCalls) +} diff --git a/backend/internal/service/admin_service_create_user_test.go b/backend/internal/service/admin_service_create_user_test.go index a0fe4d87..c5b1e38d 100644 --- a/backend/internal/service/admin_service_create_user_test.go +++ b/backend/internal/service/admin_service_create_user_test.go @@ -7,6 +7,7 @@ import ( "errors" "testing" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" ) @@ -65,3 +66,32 @@ func TestAdminService_CreateUser_CreateError(t *testing.T) { require.ErrorIs(t, err, createErr) require.Empty(t, repo.created) } + +func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) { + repo := &userRepoStub{nextID: 21} + assigner := &defaultSubscriptionAssignerStub{} + cfg := &config.Config{ + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + } + settingService := NewSettingService(&settingRepoStub{values: map[string]string{ + SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`, + }}, cfg) + svc := &adminServiceImpl{ + userRepo: repo, + settingService: settingService, + defaultSubAssigner: assigner, + } + + _, err := svc.CreateUser(context.Background(), &CreateUserInput{ + Email: "new-user@test.com", + Password: "password", + }) + require.NoError(t, err) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(21), assigner.calls[0].UserID) + require.Equal(t, int64(5), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) +} diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index c775749d..bb906df5 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -93,6 +93,10 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID panic("unexpected RemoveGroupFromAllowedGroups call") } +func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + panic("unexpected AddGroupToAllowedGroups call") +} + func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { panic("unexpected UpdateTotpSecret call") } @@ -172,6 +176,10 @@ func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs [] panic("unexpected GetAccountIDsByGroupIDs call") } +func (s *groupRepoStub) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { + return nil +} + type proxyRepoStub struct { deleteErr error countErr error diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index d921a086..ef77a980 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -116,6 +116,10 @@ func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []i panic("unexpected GetAccountIDsByGroupIDs call") } +func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error { + return nil +} + // TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递 func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) { repo := &groupRepoStubForAdmin{} @@ -395,6 +399,10 @@ func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Contex panic("unexpected GetAccountIDsByGroupIDs call") } +func (s *groupRepoStubForFallbackCycle) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error { + return nil +} + type groupRepoStubForInvalidRequestFallback struct { groups map[int64]*Group created *Group @@ -466,6 +474,10 @@ func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.C panic("unexpected BindAccountsToGroup call") } +func (s *groupRepoStubForInvalidRequestFallback) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error { + return nil +} + func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) { fallbackID := int64(10) repo := &groupRepoStubForInvalidRequestFallback{ diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go new file mode 100644 index 00000000..8b50530a --- /dev/null +++ b/backend/internal/service/admin_service_list_users_test.go @@ -0,0 +1,106 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type userRepoStubForListUsers struct { + userRepoStub + users []User + err error +} + +func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) { + if s.err != nil { + return nil, nil, s.err + } + out := make([]User, len(s.users)) + copy(out, s.users) + return out, &pagination.PaginationResult{ + Total: int64(len(out)), + Page: params.Page, + PageSize: params.PageSize, + }, nil +} + +type userGroupRateRepoStubForListUsers struct { + batchCalls int + singleCall []int64 + + batchErr error + batchData map[int64]map[int64]float64 + + singleErr map[int64]error + singleData map[int64]map[int64]float64 +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) { + s.batchCalls++ + if s.batchErr != nil { + return nil, s.batchErr + } + return s.batchData, nil +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) { + s.singleCall = append(s.singleCall, userID) + if err, ok := s.singleErr[userID]; ok { + return nil, err + } + if rates, ok := s.singleData[userID]; ok { + return rates, nil + } + return map[int64]float64{}, nil +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) { + panic("unexpected GetByUserAndGroup call") +} + +func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error { + panic("unexpected SyncUserGroupRates call") +} + +func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error { + panic("unexpected DeleteByGroupID call") +} + +func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error { + panic("unexpected DeleteByUserID call") +} + +func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { + userRepo := &userRepoStubForListUsers{ + users: []User{ + {ID: 101, Username: "u1"}, + {ID: 202, Username: "u2"}, + }, + } + rateRepo := &userGroupRateRepoStubForListUsers{ + batchErr: errors.New("batch unavailable"), + singleData: map[int64]map[int64]float64{ + 101: {11: 1.1}, + 202: {22: 2.2}, + }, + } + svc := &adminServiceImpl{ + userRepo: userRepo, + userGroupRateRepo: rateRepo, + } + + users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}) + require.NoError(t, err) + require.Equal(t, int64(2), total) + require.Len(t, users, 2) + require.Equal(t, 1, rateRepo.batchCalls) + require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall) + require.Equal(t, 1.1, users[0].GroupRates[11]) + require.Equal(t, 2.2, users[1].GroupRates[22]) +} diff --git a/backend/internal/service/admin_service_proxy_quality_test.go b/backend/internal/service/admin_service_proxy_quality_test.go new file mode 100644 index 00000000..5a43cd9c --- /dev/null +++ b/backend/internal/service/admin_service_proxy_quality_test.go @@ -0,0 +1,95 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) { + result := &ProxyQualityCheckResult{ + PassedCount: 2, + WarnCount: 1, + FailedCount: 1, + ChallengeCount: 1, + } + + finalizeProxyQualityResult(result) + + require.Equal(t, 38, result.Score) + require.Equal(t, "F", result.Grade) + require.Contains(t, result.Summary, "通过 2 项") + require.Contains(t, result.Summary, "告警 1 项") + require.Contains(t, result.Summary, "失败 1 项") + require.Contains(t, result.Summary, "挑战 1 项") +} + +func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("cf-ray", "test-ray-123") + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte("Just a moment...")) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "sora", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + } + + item := runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "challenge", item.Status) + require.Equal(t, http.StatusForbidden, item.HTTPStatus) + require.Equal(t, "test-ray-123", item.CFRay) +} + +func TestRunProxyQualityTarget_AllowedStatusPass(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models":[]}`)) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "gemini", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusOK: {}, + }, + } + + item := runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "pass", item.Status) + require.Equal(t, http.StatusOK, item.HTTPStatus) +} + +func TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "openai", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + } + + item := runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "warn", item.Status) + require.Equal(t, http.StatusUnauthorized, item.HTTPStatus) + require.Contains(t, item.Message, "目标可达") +} diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go index d661b710..ff58fd01 100644 --- a/backend/internal/service/admin_service_search_test.go +++ b/backend/internal/service/admin_service_search_test.go @@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct { listWithFiltersErr error } -func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { s.listWithFiltersCalls++ s.listWithFiltersParams = params s.listWithFiltersPlatform = platform @@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc") + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0) require.NoError(t, err) require.Equal(t, int64(10), total) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) diff --git a/backend/internal/service/anthropic_session.go b/backend/internal/service/anthropic_session.go new file mode 100644 index 00000000..26544c68 --- /dev/null +++ b/backend/internal/service/anthropic_session.go @@ -0,0 +1,79 @@ +package service + +import ( + "encoding/json" + "strings" + "time" +) + +// Anthropic 会话 Fallback 相关常量 +const ( + // anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟) + anthropicSessionTTLSeconds = 300 + + // anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀 + anthropicDigestSessionKeyPrefix = "anthropic:digest:" +) + +// AnthropicSessionTTL 返回 Anthropic 会话缓存 TTL +func AnthropicSessionTTL() time.Duration { + return anthropicSessionTTLSeconds * time.Second +} + +// BuildAnthropicDigestChain 根据 Anthropic 请求生成摘要链 +// 格式: s:-u:-a:-u:-... +// s = system, u = user, a = assistant +func BuildAnthropicDigestChain(parsed *ParsedRequest) string { + if parsed == nil { + return "" + } + + var parts []string + + // 1. system prompt + if parsed.System != nil { + systemData, _ := json.Marshal(parsed.System) + if len(systemData) > 0 && string(systemData) != "null" { + parts = append(parts, "s:"+shortHash(systemData)) + } + } + + // 2. messages + for _, msg := range parsed.Messages { + msgMap, ok := msg.(map[string]any) + if !ok { + continue + } + role, _ := msgMap["role"].(string) + prefix := rolePrefix(role) + content := msgMap["content"] + contentData, _ := json.Marshal(content) + parts = append(parts, prefix+":"+shortHash(contentData)) + } + + return strings.Join(parts, "-") +} + +// rolePrefix 将 Anthropic 的 role 映射为单字符前缀 +func rolePrefix(role string) string { + switch role { + case "assistant": + return "a" + default: + return "u" + } +} + +// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey +// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey +func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string { + prefix := prefixHash + if len(prefixHash) >= 8 { + prefix = prefixHash[:8] + } + uuidPart := uuid + if len(uuid) >= 8 { + uuidPart = uuid[:8] + } + return anthropicDigestSessionKeyPrefix + prefix + ":" + uuidPart +} diff --git a/backend/internal/service/anthropic_session_test.go b/backend/internal/service/anthropic_session_test.go new file mode 100644 index 00000000..10406643 --- /dev/null +++ b/backend/internal/service/anthropic_session_test.go @@ -0,0 +1,320 @@ +package service + +import ( + "strings" + "testing" +) + +func TestBuildAnthropicDigestChain_NilRequest(t *testing.T) { + result := BuildAnthropicDigestChain(nil) + if result != "" { + t.Errorf("expected empty string for nil request, got: %s", result) + } +} + +func TestBuildAnthropicDigestChain_EmptyMessages(t *testing.T) { + parsed := &ParsedRequest{ + Messages: []any{}, + } + result := BuildAnthropicDigestChain(parsed) + if result != "" { + t.Errorf("expected empty string for empty messages, got: %s", result) + } +} + +func TestBuildAnthropicDigestChain_SingleUserMessage(t *testing.T) { + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 1 { + t.Fatalf("expected 1 part, got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "u:") { + t.Errorf("expected prefix 'u:', got: %s", parts[0]) + } +} + +func TestBuildAnthropicDigestChain_UserAndAssistant(t *testing.T) { + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi there"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 2 { + t.Fatalf("expected 2 parts, got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "u:") { + t.Errorf("part[0] expected prefix 'u:', got: %s", parts[0]) + } + if !strings.HasPrefix(parts[1], "a:") { + t.Errorf("part[1] expected prefix 'a:', got: %s", parts[1]) + } +} + +func TestBuildAnthropicDigestChain_WithSystemString(t *testing.T) { + parsed := &ParsedRequest{ + System: "You are a helpful assistant", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 2 { + t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "s:") { + t.Errorf("part[0] expected prefix 's:', got: %s", parts[0]) + } + if !strings.HasPrefix(parts[1], "u:") { + t.Errorf("part[1] expected prefix 'u:', got: %s", parts[1]) + } +} + +func TestBuildAnthropicDigestChain_WithSystemContentBlocks(t *testing.T) { + parsed := &ParsedRequest{ + System: []any{ + map[string]any{"type": "text", "text": "You are a helpful assistant"}, + }, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 2 { + t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "s:") { + t.Errorf("part[0] expected prefix 's:', got: %s", parts[0]) + } +} + +func TestBuildAnthropicDigestChain_ConversationPrefixRelationship(t *testing.T) { + // 核心测试:验证对话增长时链的前缀关系 + // 上一轮的完整链一定是下一轮链的前缀 + system := "You are a helpful assistant" + + // 第 1 轮: system + user + round1 := &ParsedRequest{ + System: system, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + chain1 := BuildAnthropicDigestChain(round1) + + // 第 2 轮: system + user + assistant + user + round2 := &ParsedRequest{ + System: system, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi there"}, + map[string]any{"role": "user", "content": "how are you?"}, + }, + } + chain2 := BuildAnthropicDigestChain(round2) + + // 第 3 轮: system + user + assistant + user + assistant + user + round3 := &ParsedRequest{ + System: system, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi there"}, + map[string]any{"role": "user", "content": "how are you?"}, + map[string]any{"role": "assistant", "content": "I'm doing well"}, + map[string]any{"role": "user", "content": "great"}, + }, + } + chain3 := BuildAnthropicDigestChain(round3) + + t.Logf("Chain1: %s", chain1) + t.Logf("Chain2: %s", chain2) + t.Logf("Chain3: %s", chain3) + + // chain1 是 chain2 的前缀 + if !strings.HasPrefix(chain2, chain1) { + t.Errorf("chain1 should be prefix of chain2:\n chain1: %s\n chain2: %s", chain1, chain2) + } + + // chain2 是 chain3 的前缀 + if !strings.HasPrefix(chain3, chain2) { + t.Errorf("chain2 should be prefix of chain3:\n chain2: %s\n chain3: %s", chain2, chain3) + } + + // chain1 也是 chain3 的前缀(传递性) + if !strings.HasPrefix(chain3, chain1) { + t.Errorf("chain1 should be prefix of chain3:\n chain1: %s\n chain3: %s", chain1, chain3) + } +} + +func TestBuildAnthropicDigestChain_DifferentSystemProducesDifferentChain(t *testing.T) { + parsed1 := &ParsedRequest{ + System: "System A", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + parsed2 := &ParsedRequest{ + System: "System B", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + chain1 := BuildAnthropicDigestChain(parsed1) + chain2 := BuildAnthropicDigestChain(parsed2) + + if chain1 == chain2 { + t.Error("Different system prompts should produce different chains") + } + + // 但 user 部分的 hash 应该相同 + parts1 := splitChain(chain1) + parts2 := splitChain(chain2) + if parts1[1] != parts2[1] { + t.Error("Same user message should produce same hash regardless of system") + } +} + +func TestBuildAnthropicDigestChain_DifferentContentProducesDifferentChain(t *testing.T) { + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "ORIGINAL reply"}, + map[string]any{"role": "user", "content": "next"}, + }, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "TAMPERED reply"}, + map[string]any{"role": "user", "content": "next"}, + }, + } + + chain1 := BuildAnthropicDigestChain(parsed1) + chain2 := BuildAnthropicDigestChain(parsed2) + + if chain1 == chain2 { + t.Error("Different content should produce different chains") + } + + parts1 := splitChain(chain1) + parts2 := splitChain(chain2) + // 第一个 user message hash 应该相同 + if parts1[0] != parts2[0] { + t.Error("First user message hash should be the same") + } + // assistant reply hash 应该不同 + if parts1[1] == parts2[1] { + t.Error("Assistant reply hash should differ") + } +} + +func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) { + parsed := &ParsedRequest{ + System: "test system", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi"}, + }, + } + + chain1 := BuildAnthropicDigestChain(parsed) + chain2 := BuildAnthropicDigestChain(parsed) + + if chain1 != chain2 { + t.Errorf("BuildAnthropicDigestChain not deterministic: %s vs %s", chain1, chain2) + } +} + +func TestGenerateAnthropicDigestSessionKey(t *testing.T) { + tests := []struct { + name string + prefixHash string + uuid string + want string + }{ + { + name: "normal 16 char hash with uuid", + prefixHash: "abcdefgh12345678", + uuid: "550e8400-e29b-41d4-a716-446655440000", + want: "anthropic:digest:abcdefgh:550e8400", + }, + { + name: "exactly 8 chars", + prefixHash: "12345678", + uuid: "abcdefgh", + want: "anthropic:digest:12345678:abcdefgh", + }, + { + name: "short values", + prefixHash: "abc", + uuid: "xyz", + want: "anthropic:digest:abc:xyz", + }, + { + name: "empty values", + prefixHash: "", + uuid: "", + want: "anthropic:digest::", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GenerateAnthropicDigestSessionKey(tt.prefixHash, tt.uuid) + if got != tt.want { + t.Errorf("GenerateAnthropicDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want) + } + }) + } + + // 验证不同 uuid 产生不同 sessionKey + t.Run("different uuid different key", func(t *testing.T) { + hash := "sameprefix123456" + result1 := GenerateAnthropicDigestSessionKey(hash, "uuid0001-session-a") + result2 := GenerateAnthropicDigestSessionKey(hash, "uuid0002-session-b") + if result1 == result2 { + t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2) + } + }) +} + +func TestAnthropicSessionTTL(t *testing.T) { + ttl := AnthropicSessionTTL() + if ttl.Seconds() != 300 { + t.Errorf("expected 300 seconds, got: %v", ttl.Seconds()) + } +} + +func TestBuildAnthropicDigestChain_ContentBlocks(t *testing.T) { + // 测试 content 为 content blocks 数组的情况 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "describe this image"}, + map[string]any{"type": "image", "source": map[string]any{"type": "base64"}}, + }, + }, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 1 { + t.Fatalf("expected 1 part, got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "u:") { + t.Errorf("expected prefix 'u:', got: %s", parts[0]) + } +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 3d3c9cca..96ff3354 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -9,18 +9,22 @@ import ( "fmt" "io" "log" + "log/slog" mathrand "math/rand" "net" "net/http" "os" "strconv" "strings" + "sync" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/tidwall/gjson" ) const ( @@ -35,9 +39,15 @@ const ( // - 预检查:剩余限流时间 < 此阈值时等待,>= 此阈值时切换账号 antigravityRateLimitThreshold = 7 * time.Second antigravitySmartRetryMinWait = 1 * time.Second // 智能重试最小等待时间 - antigravitySmartRetryMaxAttempts = 3 // 智能重试最大次数 + antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待) antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用) + // MODEL_CAPACITY_EXHAUSTED 专用重试参数 + // 模型容量不足时,所有账号共享同一容量池,切换账号无意义 + // 使用固定 1s 间隔重试,最多重试 60 次 + antigravityModelCapacityRetryMaxAttempts = 60 + antigravityModelCapacityRetryWait = 1 * time.Second + // Google RPC 状态和类型常量 googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED" googleRPCStatusUnavailable = "UNAVAILABLE" @@ -45,6 +55,22 @@ const ( googleRPCTypeErrorInfo = "type.googleapis.com/google.rpc.ErrorInfo" googleRPCReasonModelCapacityExhausted = "MODEL_CAPACITY_EXHAUSTED" googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED" + + // 单账号 503 退避重试:Service 层原地重试的最大次数 + // 在 handleSmartRetry 中,对于 shouldRateLimitModel(长延迟 ≥ 7s)的情况, + // 多账号模式下会设限流+切换账号;但单账号模式下改为原地等待+重试。 + antigravitySingleAccountSmartRetryMaxAttempts = 3 + + // 单账号 503 退避重试:原地重试时单次最大等待时间 + // 防止上游返回过长的 retryDelay 导致请求卡住太久 + antigravitySingleAccountSmartRetryMaxWait = 15 * time.Second + + // 单账号 503 退避重试:原地重试的总累计等待时间上限 + // 超过此上限将不再重试,直接返回 503 + antigravitySingleAccountSmartRetryTotalMaxWait = 30 * time.Second + + // MODEL_CAPACITY_EXHAUSTED 全局去重:重试全部失败后的 cooldown 时间 + antigravityModelCapacityCooldown = 10 * time.Second ) // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) @@ -53,8 +79,14 @@ var antigravityPassthroughErrorMessages = []string{ "prompt is too long", } +// MODEL_CAPACITY_EXHAUSTED 全局去重:避免多个并发请求同时对同一模型进行容量耗尽重试 +var ( + modelCapacityExhaustedMu sync.RWMutex + modelCapacityExhaustedUntil = make(map[string]time.Time) // modelName -> cooldown until +) + const ( - antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL" + antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL" antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" ) @@ -100,12 +132,11 @@ type antigravityRetryLoopParams struct { accessToken string action string body []byte - quotaScope AntigravityQuotaScope c *gin.Context httpUpstream HTTPUpstream settingService *SettingService accountRepo AccountRepository // 用于智能重试的模型级别限流 - handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult + handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult requestedModel string // 用于限流检查的原始请求模型 isStickySession bool // 是否为粘性会话(用于账号切换时的缓存计费判断) groupID int64 // 用于模型级限流时清除粘性会话 @@ -117,6 +148,20 @@ type antigravityRetryLoopResult struct { resp *http.Response } +// resolveAntigravityForwardBaseURL 解析转发用 base URL。 +// 默认使用 daily(ForwardBaseURLs 的首个地址);当环境变量为 prod 时使用第二个地址。 +func resolveAntigravityForwardBaseURL() string { + baseURLs := antigravity.ForwardBaseURLs() + if len(baseURLs) == 0 { + return "" + } + mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv))) + if mode == "prod" && len(baseURLs) > 1 { + return baseURLs[1] + } + return baseURLs[0] +} + // smartRetryAction 智能重试的处理结果 type smartRetryAction int @@ -139,22 +184,33 @@ type smartRetryResult struct { func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParams, resp *http.Response, respBody []byte, baseURL string, urlIdx int, availableURLs []string) *smartRetryResult { // "Resource has been exhausted" 是 URL 级别限流,切换 URL(仅 429) if resp.StatusCode == http.StatusTooManyRequests && isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 { - log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) return &smartRetryResult{action: smartRetryActionContinueURL} } // 判断是否触发智能重试 - shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody) + shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody) // 情况1: retryDelay >= 阈值,限流模型并切换账号 if shouldRateLimitModel { - log.Printf("%s status=%d oauth_long_delay model=%s account=%d (model rate limit, switch account)", - p.prefix, resp.StatusCode, modelName, p.account.ID) + // 单账号 503 退避重试模式:不设限流、不切换账号,改为原地等待+重试 + // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 + // 多账号场景下切换账号是最优选择,但单账号场景下设限流毫无意义(只会导致双重等待)。 + if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) { + return s.handleSingleAccountRetryInPlace(p, resp, respBody, baseURL, waitDuration, modelName) + } - resetAt := time.Now().Add(antigravityDefaultRateLimitDuration) + rateLimitDuration := waitDuration + if rateLimitDuration <= 0 { + rateLimitDuration = antigravityDefaultRateLimitDuration + } + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)", + p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration, truncateForLog(respBody, 200)) + + resetAt := time.Now().Add(rateLimitDuration) if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) { - p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) - log.Printf("%s status=%d rate_limited account=%d (no scope mapping)", p.prefix, resp.StatusCode, p.account.ID) + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID) } else { s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) } @@ -170,27 +226,55 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam } } - // 情况2: retryDelay < 阈值,智能重试(最多 antigravitySmartRetryMaxAttempts 次) + // 情况2: retryDelay < 阈值(或 MODEL_CAPACITY_EXHAUSTED),智能重试 if shouldSmartRetry { var lastRetryResp *http.Response var lastRetryBody []byte - for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ { - log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d", - p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID) + // MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数(60 次,固定 1s 间隔) + maxAttempts := antigravitySmartRetryMaxAttempts + if isModelCapacityExhausted { + maxAttempts = antigravityModelCapacityRetryMaxAttempts + waitDuration = antigravityModelCapacityRetryWait + // 全局去重:如果其他 goroutine 已在重试同一模型且尚在 cooldown 中,直接返回 503 + if modelName != "" { + modelCapacityExhaustedMu.RLock() + cooldownUntil, exists := modelCapacityExhaustedUntil[modelName] + modelCapacityExhaustedMu.RUnlock() + if exists && time.Now().Before(cooldownUntil) { + log.Printf("%s status=%d model_capacity_exhausted_dedup model=%s account=%d cooldown_until=%v (skip retry)", + p.prefix, resp.StatusCode, modelName, p.account.ID, cooldownUntil.Format("15:04:05")) + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + resp: &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + }, + } + } + } + } + + for attempt := 1; attempt <= maxAttempts; attempt++ { + log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d", + p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID) + + timer := time.NewTimer(waitDuration) select { case <-p.ctx.Done(): + timer.Stop() log.Printf("%s status=context_canceled_during_smart_retry", p.prefix) return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} - case <-time.After(waitDuration): + case <-timer.C: } // 智能重试:创建新请求 retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body) if err != nil { - log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err) - p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=smart_retry_request_build_failed error=%v", p.prefix, err) + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) return &smartRetryResult{ action: smartRetryActionBreakWithResp, resp: &http.Response{ @@ -203,13 +287,19 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { - log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts) + log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts) + // 重试成功,清除 MODEL_CAPACITY_EXHAUSTED cooldown + if isModelCapacityExhausted && modelName != "" { + modelCapacityExhaustedMu.Lock() + delete(modelCapacityExhaustedUntil, modelName) + modelCapacityExhaustedMu.Unlock() + } return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp} } // 网络错误时,继续重试 if retryErr != nil || retryResp == nil { - log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr) + log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, maxAttempts, retryErr) continue } @@ -219,34 +309,84 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam } lastRetryResp = retryResp if retryResp != nil { - lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10)) _ = retryResp.Body.Close() } - // 解析新的重试信息,用于下次重试的等待时间 - if attempt < antigravitySmartRetryMaxAttempts && lastRetryBody != nil { - newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody) + // 解析新的重试信息,用于下次重试的等待时间(MODEL_CAPACITY_EXHAUSTED 使用固定循环,跳过) + if !isModelCapacityExhausted && attempt < maxAttempts && lastRetryBody != nil { + newShouldRetry, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody) if newShouldRetry && newWaitDuration > 0 { waitDuration = newWaitDuration } } } - // 所有重试都失败,限流当前模型并切换账号 - log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d (switch account)", - p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID) + // 所有重试都失败 + rateLimitDuration := waitDuration + if rateLimitDuration <= 0 { + rateLimitDuration = antigravityDefaultRateLimitDuration + } + retryBody := lastRetryBody + if retryBody == nil { + retryBody = respBody + } - resetAt := time.Now().Add(antigravityDefaultRateLimitDuration) + // MODEL_CAPACITY_EXHAUSTED:模型容量不足,切换账号无意义 + // 直接返回上游错误响应,不设置模型限流,不切换账号 + if isModelCapacityExhausted { + // 设置 cooldown,让后续请求快速失败,避免重复重试 + if modelName != "" { + modelCapacityExhaustedMu.Lock() + modelCapacityExhaustedUntil[modelName] = time.Now().Add(antigravityModelCapacityCooldown) + modelCapacityExhaustedMu.Unlock() + } + log.Printf("%s status=%d smart_retry_exhausted_model_capacity attempts=%d model=%s account=%d body=%s (model capacity exhausted, not switching account)", + p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200)) + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + resp: &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryBody)), + }, + } + } + + // 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号, + // 直接返回 503 让 Handler 层的单账号退避循环做最终处理。 + if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) { + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)", + p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200)) + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + resp: &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryBody)), + }, + } + } + + log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)", + p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200)) + + resetAt := time.Now().Add(rateLimitDuration) if p.accountRepo != nil && modelName != "" { if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil { - log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err) } else { - log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", - p.prefix, resp.StatusCode, modelName, p.account.ID, antigravityDefaultRateLimitDuration) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", + p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration) s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) } } + // 清除粘性会话绑定,避免下次请求仍命中限流账号 + if s.cache != nil && p.sessionHash != "" { + _ = s.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash) + } + // 返回账号切换信号,让上层切换账号重试 return &smartRetryResult{ action: smartRetryActionBreakWithResp, @@ -262,24 +402,149 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam return &smartRetryResult{action: smartRetryActionContinue} } +// handleSingleAccountRetryInPlace 单账号 503 退避重试的原地重试逻辑。 +// +// 在多账号场景下,收到 503 + 长 retryDelay(≥ 7s)时会设置模型限流 + 切换账号; +// 但在单账号场景下,设限流毫无意义(因为切换回来的还是同一个账号,还要等限流过期)。 +// 此方法改为在 Service 层原地等待 + 重试,避免双重等待问题: +// +// 旧流程:Service 设限流 → Handler 退避等待 → Service 等限流过期 → 再请求(总耗时 = 退避 + 限流) +// 新流程:Service 直接等 retryDelay → 重试 → 成功/再等 → 重试...(总耗时 ≈ 实际 retryDelay × 重试次数) +// +// 约束: +// - 单次等待不超过 antigravitySingleAccountSmartRetryMaxWait +// - 总累计等待不超过 antigravitySingleAccountSmartRetryTotalMaxWait +// - 最多重试 antigravitySingleAccountSmartRetryMaxAttempts 次 +func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( + p antigravityRetryLoopParams, + resp *http.Response, + respBody []byte, + baseURL string, + waitDuration time.Duration, + modelName string, +) *smartRetryResult { + // 限制单次等待时间 + if waitDuration > antigravitySingleAccountSmartRetryMaxWait { + waitDuration = antigravitySingleAccountSmartRetryMaxWait + } + if waitDuration < antigravitySmartRetryMinWait { + waitDuration = antigravitySmartRetryMinWait + } + + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)", + p.prefix, resp.StatusCode, modelName, p.account.ID, waitDuration) + + var lastRetryResp *http.Response + var lastRetryBody []byte + totalWaited := time.Duration(0) + + for attempt := 1; attempt <= antigravitySingleAccountSmartRetryMaxAttempts; attempt++ { + // 检查累计等待是否超限 + if totalWaited+waitDuration > antigravitySingleAccountSmartRetryTotalMaxWait { + remaining := antigravitySingleAccountSmartRetryTotalMaxWait - totalWaited + if remaining <= 0 { + logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up", + p.prefix, totalWaited, antigravitySingleAccountSmartRetryTotalMaxWait) + break + } + waitDuration = remaining + } + + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d", + p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID) + + timer := time.NewTimer(waitDuration) + select { + case <-p.ctx.Done(): + timer.Stop() + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_single_account_retry", p.prefix) + return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} + case <-timer.C: + } + totalWaited += waitDuration + + // 创建新请求 + retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body) + if err != nil { + logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: request_build_failed error=%v", p.prefix, err) + break + } + + retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) + if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v", + p.prefix, retryResp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited) + // 关闭之前的响应 + if lastRetryResp != nil { + _ = lastRetryResp.Body.Close() + } + return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp} + } + + // 网络错误时继续重试 + if retryErr != nil || retryResp == nil { + logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: network_error attempt=%d/%d error=%v", + p.prefix, attempt, antigravitySingleAccountSmartRetryMaxAttempts, retryErr) + continue + } + + // 关闭之前的响应 + if lastRetryResp != nil { + _ = lastRetryResp.Body.Close() + } + lastRetryResp = retryResp + lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10)) + _ = retryResp.Body.Close() + + // 解析新的重试信息,更新下次等待时间 + if attempt < antigravitySingleAccountSmartRetryMaxAttempts && lastRetryBody != nil { + _, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody) + if newWaitDuration > 0 { + waitDuration = newWaitDuration + if waitDuration > antigravitySingleAccountSmartRetryMaxWait { + waitDuration = antigravitySingleAccountSmartRetryMaxWait + } + if waitDuration < antigravitySmartRetryMinWait { + waitDuration = antigravitySmartRetryMinWait + } + } + } + } + + // 所有重试都失败,不设限流,直接返回 503 + // Handler 层的单账号退避循环会做最终处理 + retryBody := lastRetryBody + if retryBody == nil { + retryBody = respBody + } + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)", + p.prefix, resp.StatusCode, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited, modelName, p.account.ID, truncateForLog(retryBody, 200)) + + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + resp: &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryBody)), + }, + } +} + // antigravityRetryLoop 执行带 URL fallback 的重试循环 func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { - // 预检查:如果账号已限流,根据剩余时间决定等待或切换 + // 预检查:如果账号已限流,直接返回切换信号 if p.requestedModel != "" { if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 { - if remaining < antigravityRateLimitThreshold { - // 限流剩余时间较短,等待后继续 - log.Printf("%s pre_check: rate_limit_wait remaining=%v model=%s account=%d", + // 单账号 503 退避重试模式:跳过限流预检查,直接发请求。 + // 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。 + // 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace + // 会在 Service 层原地等待+重试,不需要在预检查这里等。 + if isSingleAccountRetry(p.ctx) { + logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)", p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) - select { - case <-p.ctx.Done(): - return nil, p.ctx.Err() - case <-time.After(remaining): - } } else { - // 限流剩余时间较长,返回账号切换信号 - log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", - p.prefix, remaining.Truncate(time.Second), p.requestedModel, p.account.ID) + logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", + p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) return nil, &AntigravityAccountSwitchError{ OriginalAccountID: p.account.ID, RateLimitedModel: p.requestedModel, @@ -289,10 +554,11 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP } } - availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() - if len(availableURLs) == 0 { - availableURLs = antigravity.BaseURLs + baseURL := resolveAntigravityForwardBaseURL() + if baseURL == "" { + return nil, errors.New("no antigravity forward base url configured") } + availableURLs := []string{baseURL} var resp *http.Response var usedBaseURL string @@ -314,7 +580,7 @@ urlFallbackLoop: for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { select { case <-p.ctx.Done(): - log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err()) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled error=%v", p.prefix, p.ctx.Err()) return nil, p.ctx.Err() default: } @@ -344,103 +610,118 @@ urlFallbackLoop: Message: safeErr, }) if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) continue urlFallbackLoop } if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() } continue } - log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=request_failed retries_exhausted error=%v", p.prefix, err) setOpsUpstreamError(p.c, 0, safeErr, "") return nil, fmt.Errorf("upstream request failed after retries: %w", err) } - // 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + // 统一处理错误响应 + if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() - // 尝试智能重试处理(OAuth 账号专用) - smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs) - switch smartResult.action { - case smartRetryActionContinueURL: - continue urlFallbackLoop - case smartRetryActionBreakWithResp: - if smartResult.err != nil { - return nil, smartResult.err + // ★ 统一入口:自定义错误码 + 临时不可调度 + if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { + if policyErr != nil { + return nil, policyErr } - // 模型限流时返回切换账号信号 - if smartResult.switchError != nil { - return nil, smartResult.switchError + resp = &http.Response{ + StatusCode: outStatus, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), } - resp = smartResult.resp break urlFallbackLoop } - // smartRetryActionContinue: 继续默认重试逻辑 - // 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败) - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ - Platform: p.account.Platform, - AccountID: p.account.ID, - AccountName: p.account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: getUpstreamDetail(respBody), - }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) - return nil, p.ctx.Err() + // 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流 + if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + // 尝试智能重试处理(OAuth 账号专用) + smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs) + switch smartResult.action { + case smartRetryActionContinueURL: + continue urlFallbackLoop + case smartRetryActionBreakWithResp: + if smartResult.err != nil { + return nil, smartResult.err + } + // 模型限流时返回切换账号信号 + if smartResult.switchError != nil { + return nil, smartResult.switchError + } + resp = smartResult.resp + break urlFallbackLoop } - continue + // smartRetryActionContinue: 继续默认重试逻辑 + + // 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败) + if attempt < antigravityMaxRetries { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ + Platform: p.account.Platform, + AccountID: p.account.ID, + AccountName: p.account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry", + Message: upstreamMsg, + Detail: getUpstreamDetail(respBody), + }) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix) + return nil, p.ctx.Err() + } + continue + } + + // 重试用尽,标记账户限流 + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop } - // 重试用尽,标记账户限流 - p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) - log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break urlFallbackLoop - } - - // 其他可重试错误(不包括 429 和 503,因为上面已处理) - if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ - Platform: p.account.Platform, - AccountID: p.account.ID, - AccountName: p.account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: getUpstreamDetail(respBody), - }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) - if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) - return nil, p.ctx.Err() - } - continue - } + // 其他可重试错误(500/502/504/529,不包括 429 和 503) + if shouldRetryAntigravityError(resp.StatusCode) { + if attempt < antigravityMaxRetries { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ + Platform: p.account.Platform, + AccountID: p.account.ID, + AccountName: p.account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry", + Message: upstreamMsg, + Detail: getUpstreamDetail(respBody), + }) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix) + return nil, p.ctx.Err() + } + continue + } + } + + // 其他 4xx 错误或重试用尽,直接返回 resp = &http.Response{ StatusCode: resp.StatusCode, Header: resp.Header.Clone(), @@ -449,6 +730,7 @@ urlFallbackLoop: break urlFallbackLoop } + // 成功响应(< 400) break urlFallbackLoop } } @@ -581,6 +863,32 @@ func (s *AntigravityGatewayService) getUpstreamErrorDetail(body []byte) string { return truncateString(string(body), maxBytes) } +// checkErrorPolicy nil 安全的包装 +func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, account *Account, statusCode int, body []byte) ErrorPolicyResult { + if s.rateLimitService == nil { + return ErrorPolicyNone + } + return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body) +} + +// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环及应返回的状态码。 +// ErrorPolicySkipped 时 outStatus 为 500(前端约定:未命中的错误返回 500)。 +func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) { + switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) { + case ErrorPolicySkipped: + return true, http.StatusInternalServerError, nil + case ErrorPolicyMatched: + _ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody, + p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) + return true, statusCode, nil + case ErrorPolicyTempUnscheduled: + slog.Info("temp_unschedulable_matched", + "prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID) + return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} + } + return false, statusCode, nil +} + // mapAntigravityModel 获取映射后的模型名 // 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底(DefaultAntigravityModelMapping) // 注意:返回空字符串表示模型不被支持,调度时会过滤掉该账号 @@ -650,6 +958,7 @@ type TestConnectionResult struct { // TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) // 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { + // 获取 token if s.tokenProvider == nil { return nil, errors.New("antigravity token provider not configured") @@ -687,11 +996,11 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account proxyURL = account.Proxy.URL() } - // URL fallback 循环 - availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() - if len(availableURLs) == 0 { - availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 + baseURL := resolveAntigravityForwardBaseURL() + if baseURL == "" { + return nil, errors.New("no antigravity forward base url configured") } + availableURLs := []string{baseURL} var lastErr error for urlIdx, baseURL := range availableURLs { @@ -703,14 +1012,14 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account } // 调试日志:Test 请求信息 - log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) // 发送请求 resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) if err != nil { lastErr = fmt.Errorf("请求失败: %w", err) if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) continue } return nil, lastErr @@ -725,7 +1034,7 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account // 检查是否需要 URL 降级 if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { - log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) continue } @@ -934,16 +1243,12 @@ func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model strin } // unwrapV1InternalResponse 解包 v1internal 响应 +// 使用 gjson 零拷贝提取 response 字段,避免 Unmarshal+Marshal 双重开销 func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) { - var outer map[string]any - if err := json.Unmarshal(body, &outer); err != nil { - return nil, err + result := gjson.GetBytes(body, "response") + if result.Exists() { + return []byte(result.Raw), nil } - - if resp, ok := outer["response"]; ok { - return json.Marshal(resp) - } - return body, nil } @@ -964,8 +1269,24 @@ func isModelNotFoundError(statusCode int, body []byte) bool { } // Forward 转发 Claude 协议请求(Claude → Gemini 转换) +// +// 限流处理流程: +// +// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游 +// ├─ 成功 → 正常返回 +// └─ 429/503 → handleSmartRetry +// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号 +// └─ retryDelay < 7s → 等待后重试 1 次 +// ├─ 成功 → 正常返回 +// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号 func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { + // 上游透传账号直接转发,不走 OAuth token 刷新 + if account.Type == AccountTypeUpstream { + return s.ForwardUpstream(ctx, c, account, body) + } + startTime := time.Now() + sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -983,11 +1304,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if mappedModel == "" { return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) } - loadModel := mappedModel // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 - thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) + billingModel := mappedModel // 获取 access_token if s.tokenProvider == nil { @@ -1022,11 +1342,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 action := "streamGenerateContent" - // 统计模型调用次数(包括粘性会话,用于负载均衡调度) - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel) - } - // 执行带重试的请求 result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{ ctx: ctx, @@ -1036,7 +1351,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, accessToken: accessToken, action: action, body: geminiBody, - quotaScope: quotaScope, c: c, httpUpstream: s.httpUpstream, settingService: s.settingService, @@ -1055,6 +1369,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ForceCacheBilling: switchErr.IsStickySession, } } + // 区分客户端取消和真正的上游失败,返回更准确的错误消息 + if c.Request.Context().Err() != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response") + } return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") } resp := result.resp @@ -1103,7 +1421,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, continue } - log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name) + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name) retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx)) if txErr != nil { @@ -1117,7 +1435,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, accessToken: accessToken, action: action, body: retryGeminiBody, - quotaScope: quotaScope, c: c, httpUpstream: s.httpUpstream, settingService: s.settingService, @@ -1137,7 +1454,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Kind: "signature_retry_request_error", Message: sanitizeUpstreamErrorMessage(retryErr.Error()), }) - log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr) + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr) continue } @@ -1149,14 +1466,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, break } - retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 8<<10)) _ = retryResp.Body.Close() if retryResp.StatusCode == http.StatusTooManyRequests { retryBaseURL := "" if retryResp.Request != nil && retryResp.Request.URL != nil { retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host } - log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200)) } kind := "signature_retry" if strings.TrimSpace(stage.name) != "" { @@ -1209,7 +1526,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, upstreamDetail := s.getUpstreamErrorDetail(respBody) logBody, maxBytes := s.getLogConfig() if logBody { - log.Printf("%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes)) } appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -1228,7 +1545,28 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession) + + // 精确匹配服务端配置类 400 错误,触发同账号重试 + failover + if resp.StatusCode == http.StatusBadRequest { + msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) + if isGoogleProjectConfigError(msg) { + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) + upstreamDetail := s.getUpstreamErrorDetail(respBody) + log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true} + } + } if s.shouldFailoverUpstreamError(resp.StatusCode) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) @@ -1258,20 +1596,22 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, var usage *ClaudeUsage var firstTokenMs *int + var clientDisconnect bool if claudeReq.Stream { // 客户端要求流式,直接透传转换 streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err) return nil, err } usage = streamRes.usage firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect } else { // 客户端要求非流式,收集流式响应后转换返回 streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -1279,12 +1619,13 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, // 使用原始模型用于计费和日志 - Stream: claudeReq.Stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: *usage, + Model: billingModel, // 使用映射模型用于计费和日志 + Stream: claudeReq.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, }, nil } @@ -1583,8 +1924,19 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque } // ForwardGemini 转发 Gemini 协议请求 +// +// 限流处理流程: +// +// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游 +// ├─ 成功 → 正常返回 +// └─ 429/503 → handleSmartRetry +// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号 +// └─ retryDelay < 7s → 等待后重试 1 次 +// ├─ 成功 → 正常返回 +// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号 func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() + sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -1597,7 +1949,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if len(body) == 0 { return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") } - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 解析请求以获取 image_size(用于图片计费) imageSize := s.extractImageSize(body) @@ -1613,7 +1964,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co Usage: ClaudeUsage{}, Model: originalModel, Stream: false, - Duration: time.Since(time.Now()), + Duration: time.Since(startTime), FirstTokenMs: nil, }, nil default: @@ -1624,6 +1975,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if mappedModel == "" { return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) } + billingModel := mappedModel // 获取 access_token if s.tokenProvider == nil { @@ -1652,9 +2004,9 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 清理 Schema if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil { injectedBody = cleanedBody - log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name) } else { - log.Printf("[Antigravity] Failed to clean schema: %v", err) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Failed to clean schema: %v", err) } // 包装请求 @@ -1667,11 +2019,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 upstreamAction := "streamGenerateContent" - // 统计模型调用次数(包括粘性会话,用于负载均衡调度) - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) - } - // 执行带重试的请求 result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{ ctx: ctx, @@ -1681,7 +2028,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co accessToken: accessToken, action: upstreamAction, body: wrappedBody, - quotaScope: quotaScope, c: c, httpUpstream: s.httpUpstream, settingService: s.settingService, @@ -1700,6 +2046,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ForceCacheBilling: switchErr.IsStickySession, } } + // 区分客户端取消和真正的上游失败,返回更准确的错误消息 + if c.Request.Context().Err() != nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Client disconnected before upstream response") + } return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } resp := result.resp @@ -1722,7 +2072,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co isModelNotFoundError(resp.StatusCode, respBody) { fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity) if fallbackModel != "" && fallbackModel != mappedModel { - log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody) if err == nil { @@ -1755,7 +2105,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if unwrapErr != nil || len(unwrappedForOps) == 0 { unwrappedForOps = respBody } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamDetail := s.getUpstreamErrorDetail(unwrappedForOps) @@ -1763,6 +2113,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // Always record upstream context for Ops error logs, even when we will failover. setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + // 精确匹配服务端配置类 400 错误,触发同账号重试 + failover + if resp.StatusCode == http.StatusBadRequest && isGoogleProjectConfigError(strings.ToLower(upstreamMsg)) { + log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps, RetryableOnSameAccount: true} + } + if s.shouldFailoverUpstreamError(resp.StatusCode) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -1789,7 +2155,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co Message: upstreamMsg, Detail: upstreamDetail, }) - log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500)) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500)) c.Data(resp.StatusCode, contentType, unwrappedForOps) return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } @@ -1802,21 +2168,23 @@ handleSuccess: var usage *ClaudeUsage var firstTokenMs *int + var clientDisconnect bool if stream { // 客户端要求流式,直接透传 streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime) if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err) return nil, err } usage = streamRes.usage firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect } else { // 客户端要求非流式,收集流式响应后返回 streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -1835,14 +2203,15 @@ handleSuccess: } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Usage: *usage, + Model: billingModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + ImageCount: imageCount, + ImageSize: imageSize, }, nil } @@ -1855,6 +2224,44 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) } } +// isGoogleProjectConfigError 判断(已提取的小写)错误消息是否属于 Google 服务端配置类问题。 +// 只精确匹配已知的服务端侧错误,避免对客户端请求错误做无意义重试。 +// 适用于所有走 Google 后端的平台(Antigravity、Gemini)。 +func isGoogleProjectConfigError(lowerMsg string) bool { + // Google 间歇性 Bug:Project ID 有效但被临时识别失败 + return strings.Contains(lowerMsg, "invalid project resource name") +} + +// googleConfigErrorCooldown 服务端配置类 400 错误的临时封禁时长 +const googleConfigErrorCooldown = 1 * time.Minute + +// tempUnscheduleGoogleConfigError 对服务端配置类 400 错误触发临时封禁, +// 避免短时间内反复调度到同一个有问题的账号。 +func tempUnscheduleGoogleConfigError(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) { + until := time.Now().Add(googleConfigErrorCooldown) + reason := "400: invalid project resource name (auto temp-unschedule 1m)" + if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil { + log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err) + } else { + log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason) + } +} + +// emptyResponseCooldown 空流式响应的临时封禁时长 +const emptyResponseCooldown = 1 * time.Minute + +// tempUnscheduleEmptyResponse 对空流式响应触发临时封禁, +// 避免短时间内反复调度到同一个返回空响应的账号。 +func tempUnscheduleEmptyResponse(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) { + until := time.Now().Add(emptyResponseCooldown) + reason := "empty stream response (auto temp-unschedule 1m)" + if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil { + log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err) + } else { + log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason) + } +} + // sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待 // 返回 true 表示正常完成等待,false 表示 context 已取消 func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { @@ -1871,14 +2278,22 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { sleepFor = 0 } + timer := time.NewTimer(sleepFor) select { case <-ctx.Done(): + timer.Stop() return false - case <-time.After(sleepFor): + case <-timer.C: return true } } +// isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记 +func isSingleAccountRetry(ctx context.Context) bool { + v, _ := SingleAccountRetryFromContext(ctx) + return v +} + // setModelRateLimitByModelName 使用官方模型 ID 设置模型级限流 // 直接使用上游返回的模型 ID(如 claude-sonnet-4-5)作为限流 key // 返回是否已成功设置(若模型名为空或 repo 为 nil 将返回 false) @@ -1888,13 +2303,13 @@ func setModelRateLimitByModelName(ctx context.Context, repo AccountRepository, a } // 直接使用官方模型 ID 作为 key,不再转换为 scope if err := repo.SetModelRateLimit(ctx, accountID, modelName, resetAt); err != nil { - log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err) return false } if afterSmartRetry { - log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) } else { - log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) } return true } @@ -1913,8 +2328,9 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) { // antigravitySmartRetryInfo 智能重试所需的信息 type antigravitySmartRetryInfo struct { - RetryDelay time.Duration // 重试延迟时间 - ModelName string // 限流的模型名称(如 "claude-sonnet-4-5") + RetryDelay time.Duration // 重试延迟时间 + ModelName string // 限流的模型名称(如 "claude-sonnet-4-5") + IsModelCapacityExhausted bool // 是否为模型容量不足(MODEL_CAPACITY_EXHAUSTED) } // parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息 @@ -2001,7 +2417,7 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo { // 例如: "0.5s", "10s", "4m50s", "1h30m", "200ms" 等 dur, err := time.ParseDuration(delay) if err != nil { - log.Printf("[Antigravity] failed to parse retryDelay: %s error=%v", delay, err) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] failed to parse retryDelay: %s error=%v", delay, err) continue } retryDelay = dur @@ -2029,31 +2445,40 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo { } return &antigravitySmartRetryInfo{ - RetryDelay: retryDelay, - ModelName: modelName, + RetryDelay: retryDelay, + ModelName: modelName, + IsModelCapacityExhausted: hasModelCapacityExhausted, } } // shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试 // 返回: -// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold) -// - shouldRateLimitModel: 是否应该限流模型(retryDelay >= antigravityRateLimitThreshold) -// - waitDuration: 等待时间(智能重试时使用,shouldRateLimitModel=true 时为 0) +// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold,或 MODEL_CAPACITY_EXHAUSTED) +// - shouldRateLimitModel: 是否应该限流模型并切换账号(仅 RATE_LIMIT_EXCEEDED 且 retryDelay >= 阈值) +// - waitDuration: 等待时间 // - modelName: 限流的模型名称 -func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) { +// - isModelCapacityExhausted: 是否为模型容量不足(MODEL_CAPACITY_EXHAUSTED) +func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string, isModelCapacityExhausted bool) { if account.Platform != PlatformAntigravity { - return false, false, 0, "" + return false, false, 0, "", false } info := parseAntigravitySmartRetryInfo(respBody) if info == nil { - return false, false, 0, "" + return false, false, 0, "", false } + // MODEL_CAPACITY_EXHAUSTED(模型容量不足):所有账号共享同一模型容量池 + // 切换账号无意义,使用固定 1s 间隔重试 + if info.IsModelCapacityExhausted { + return true, false, antigravityModelCapacityRetryWait, info.ModelName, true + } + + // RATE_LIMIT_EXCEEDED(账号级限流): // retryDelay >= 阈值:直接限流模型,不重试 - // 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 5 分钟 + // 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 30s if info.RetryDelay >= antigravityRateLimitThreshold { - return false, true, 0, info.ModelName + return false, true, info.RetryDelay, info.ModelName, false } // retryDelay < 阈值:智能重试 @@ -2062,7 +2487,7 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou waitDuration = antigravitySmartRetryMinWait } - return true, false, waitDuration, info.ModelName + return true, false, waitDuration, info.ModelName, false } // handleModelRateLimitParams 模型级限流处理参数 @@ -2088,8 +2513,9 @@ type handleModelRateLimitResult struct { // handleModelRateLimit 处理模型级限流(在原有逻辑之前调用) // 仅处理 429/503,解析模型名和 retryDelay -// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true,由调用方等待后重试 -// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError +// - MODEL_CAPACITY_EXHAUSTED: 返回 Handled=true(实际重试由 handleSmartRetry 处理) +// - RATE_LIMIT_EXCEEDED + retryDelay < 阈值: 返回 ShouldRetry=true,由调用方等待后重试 +// - RATE_LIMIT_EXCEEDED + retryDelay >= 阈值: 设置模型限流 + 清除粘性会话 + 返回 SwitchError func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult { if p.statusCode != 429 && p.statusCode != 503 { return &handleModelRateLimitResult{Handled: false} @@ -2100,9 +2526,19 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit return &handleModelRateLimitResult{Handled: false} } - // < antigravityRateLimitThreshold: 等待后重试 + // MODEL_CAPACITY_EXHAUSTED:模型容量不足,所有账号共享同一容量池 + // 切换账号无意义,不设置模型限流(实际重试由 handleSmartRetry 处理) + if info.IsModelCapacityExhausted { + log.Printf("%s status=%d model_capacity_exhausted model=%s (not switching account, retry handled by smart retry)", + p.prefix, p.statusCode, info.ModelName) + return &handleModelRateLimitResult{ + Handled: true, + } + } + + // RATE_LIMIT_EXCEEDED: < antigravityRateLimitThreshold: 等待后重试 if info.RetryDelay < antigravityRateLimitThreshold { - log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_wait model=%s wait=%v", p.prefix, p.statusCode, info.ModelName, info.RetryDelay) return &handleModelRateLimitResult{ Handled: true, @@ -2111,7 +2547,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit } } - // >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号 + // RATE_LIMIT_EXCEEDED: >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号 s.setModelRateLimitAndClearSession(p, info) return &handleModelRateLimitResult{ @@ -2127,12 +2563,12 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit // setModelRateLimitAndClearSession 设置模型限流并清除粘性会话 func (s *AntigravityGatewayService) setModelRateLimitAndClearSession(p *handleModelRateLimitParams, info *antigravitySmartRetryInfo) { resetAt := time.Now().Add(info.RetryDelay) - log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited model=%s account=%d reset_in=%v", p.prefix, p.statusCode, info.ModelName, p.account.ID, info.RetryDelay) // 设置模型限流状态(数据库) if err := s.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, info.ModelName, resetAt); err != nil { - log.Printf("%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err) } // 立即更新 Redis 快照中账号的限流状态,避免并发请求重复选中 @@ -2168,17 +2604,21 @@ func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx conte // 更新 Redis 快照 if err := s.schedulerSnapshot.UpdateAccountInCache(ctx, account); err != nil { - log.Printf("[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err) } } func (s *AntigravityGatewayService) handleUpstreamError( ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, - quotaScope AntigravityQuotaScope, + requestedModel string, groupID int64, sessionHash string, isStickySession bool, ) *handleModelRateLimitResult { - // ✨ 模型级限流处理(在原有逻辑之前) + // 遵守自定义错误码策略:未命中则跳过所有限流处理 + if !account.ShouldHandleErrorCode(statusCode) { + return nil + } + // 模型级限流处理(优先) result := s.handleModelRateLimit(&handleModelRateLimitParams{ ctx: ctx, prefix: prefix, @@ -2200,52 +2640,44 @@ func (s *AntigravityGatewayService) handleUpstreamError( return nil } - // ========== 原有逻辑,保持不变 ========== - // 429 使用 Gemini 格式解析(从 body 解析重置时间) + // 429:尝试解析模型级限流,解析失败时兜底为账号级限流 if statusCode == 429 { - // 调试日志遵循统一日志开关与长度限制,避免无条件记录完整上游响应体。 if logBody, maxBytes := s.getLogConfig(); logBody { - log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes)) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes)) } - useScopeLimit := quotaScope != "" resetAt := ParseGeminiRateLimitResetTime(body) - if resetAt == nil { - // 解析失败:使用默认限流时间(与临时限流保持一致) - // 可通过配置或环境变量覆盖 - defaultDur := antigravityDefaultRateLimitDuration - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 { - defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute - } - // 秒级环境变量优先级最高 - if override, ok := antigravityFallbackCooldownSeconds(); ok { - defaultDur = override - } - ra := time.Now().Add(defaultDur) - if useScopeLimit { - log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) - if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil { - log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) - } + defaultDur := s.getDefaultRateLimitDuration() + + // 尝试解析模型 key 并设置模型级限流 + // + // 注意:requestedModel 可能是"映射前"的请求模型名(例如 claude-opus-4-6), + // 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。 + // 因此这里必须写入最终模型 key,确保后续调度能正确避开已限流模型。 + modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel) + if strings.TrimSpace(modelKey) == "" { + // 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过), + // 保持旧行为作为兜底,避免完全丢失模型级限流记录。 + modelKey = resolveAntigravityModelKey(requestedModel) + } + if modelKey != "" { + ra := s.resolveResetTime(resetAt, defaultDur) + if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil { + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err) } else { - log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur) - if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil { - log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) - } + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v", + prefix, modelKey, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second)) + s.updateAccountModelRateLimitInCache(ctx, account, modelKey, ra) } return nil } - resetTime := time.Unix(*resetAt, 0) - if useScopeLimit { - log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) - if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil { - log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) - } - } else { - log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) - if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil { - log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) - } + + // 无法解析模型 key,兜底为账号级限流 + ra := s.resolveResetTime(resetAt, defaultDur) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)", + prefix, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second)) + if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil { + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) } return nil } @@ -2255,14 +2687,95 @@ func (s *AntigravityGatewayService) handleUpstreamError( } shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) if shouldDisable { - log.Printf("%s status=%d marked_error", prefix, statusCode) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d marked_error", prefix, statusCode) } return nil } +// getDefaultRateLimitDuration 获取默认限流时间 +func (s *AntigravityGatewayService) getDefaultRateLimitDuration() time.Duration { + defaultDur := antigravityDefaultRateLimitDuration + if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 { + defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute + } + if override, ok := antigravityFallbackCooldownSeconds(); ok { + defaultDur = override + } + return defaultDur +} + +// resolveResetTime 根据解析的重置时间或默认时长计算重置时间点 +func (s *AntigravityGatewayService) resolveResetTime(resetAt *int64, defaultDur time.Duration) time.Time { + if resetAt != nil { + return time.Unix(*resetAt, 0) + } + return time.Now().Add(defaultDur) +} + type antigravityStreamResult struct { - usage *ClaudeUsage - firstTokenMs *int + usage *ClaudeUsage + firstTokenMs *int + clientDisconnect bool // 客户端是否在流式传输过程中断开 +} + +// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。 +// 断开后所有写入操作变为 no-op,调用方通过 Disconnected() 判断是否继续 drain 上游。 +type antigravityClientWriter struct { + w gin.ResponseWriter + flusher http.Flusher + disconnected bool + prefix string // 日志前缀,标识来源方法 +} + +func newAntigravityClientWriter(w gin.ResponseWriter, flusher http.Flusher, prefix string) *antigravityClientWriter { + return &antigravityClientWriter{w: w, flusher: flusher, prefix: prefix} +} + +// Write 写入数据到客户端,写入失败时标记断开并返回 false +func (cw *antigravityClientWriter) Write(p []byte) bool { + if cw.disconnected { + return false + } + if _, err := cw.w.Write(p); err != nil { + cw.markDisconnected() + return false + } + cw.flusher.Flush() + return true +} + +// Fprintf 格式化写入数据到客户端,写入失败时标记断开并返回 false +func (cw *antigravityClientWriter) Fprintf(format string, args ...any) bool { + if cw.disconnected { + return false + } + if _, err := fmt.Fprintf(cw.w, format, args...); err != nil { + cw.markDisconnected() + return false + } + cw.flusher.Flush() + return true +} + +func (cw *antigravityClientWriter) Disconnected() bool { return cw.disconnected } + +func (cw *antigravityClientWriter) markDisconnected() { + cw.disconnected = true + logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix) +} + +// handleStreamReadError 处理上游读取错误的通用逻辑。 +// 返回 (clientDisconnect, handled):handled=true 表示错误已处理,调用方应返回已收集的 usage。 +func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + logger.LegacyPrintf("service.antigravity_gateway", "Context canceled during streaming (%s), returning collected usage", prefix) + return true, true + } + if clientDisconnected { + logger.LegacyPrintf("service.antigravity_gateway", "Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err) + return true, true + } + return false, false } func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) { @@ -2288,7 +2801,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int @@ -2309,7 +2823,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2320,7 +2835,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -2338,10 +2853,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context intervalCh = intervalTicker.C } + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini") + // 仅发送一次错误事件,避免多次写入导致协议混乱 errorEventSent := false sendErrorEvent := func(reason string) { - if errorEventSent { + if errorEventSent || cw.Disconnected() { return } errorEventSent = true @@ -2353,11 +2870,14 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context select { case ev, ok := <-events: if !ok { - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil } if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity gemini"); handled { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil + } if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } @@ -2370,11 +2890,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if strings.HasPrefix(trimmed, "data:") { payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) if payload == "" || payload == "[DONE]" { - if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("%s\n", line) continue } @@ -2385,19 +2901,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } // 解析 usage + if u := extractGeminiUsage(inner); u != nil { + usage = u + } var parsed map[string]any if json.Unmarshal(inner, &parsed) == nil { - if u := extractGeminiUsage(parsed); u != nil { - usage = u - } // Check for MALFORMED_FUNCTION_CALL if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 { if cand, ok := candidates[0].(map[string]any); ok { if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" { - log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream") + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream") if content, ok := cand["content"]; ok { if b, err := json.Marshal(content); err == nil { - log.Printf("[Antigravity] Malformed content: %s", string(b)) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Malformed content: %s", string(b)) } } } @@ -2410,27 +2926,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context firstTokenMs = &ms } - if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("data: %s\n\n", payload) continue } - if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("%s\n", line) case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) if time.Since(lastRead) < streamInterval { continue } - log.Printf("Stream data interval timeout (antigravity)") - // 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout + if cw.Disconnected() { + logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity gemini), returning collected usage") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)") sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } @@ -2445,7 +2956,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int @@ -2473,7 +2985,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2484,7 +2997,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -2511,7 +3024,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont } if ev.err != nil { if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err) } return nil, ev.err } @@ -2548,7 +3061,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont last = parsed // 提取 usage - if u := extractGeminiUsage(parsed); u != nil { + if u := extractGeminiUsage(inner); u != nil { usage = u } @@ -2556,10 +3069,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 { if cand, ok := candidates[0].(map[string]any); ok { if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" { - log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect") + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect") if content, ok := cand["content"]; ok { if b, err := json.Marshal(content); err == nil { - log.Printf("[Antigravity] Malformed content: %s", string(b)) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Malformed content: %s", string(b)) } } } @@ -2586,7 +3099,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if time.Since(lastRead) < streamInterval { continue } - log.Printf("Stream data interval timeout (antigravity non-stream)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity non-stream)") return nil, fmt.Errorf("stream data interval timeout") } } @@ -2595,9 +3108,14 @@ returnResponse: // 选择最后一个有效响应 finalResponse := pickGeminiCollectResult(last, lastWithParts) - // 处理空响应情况 + // 处理空响应情况 — 触发同账号重试 + failover 切换账号 if last == nil && lastWithParts == nil { - log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received") + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover") + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: []byte(`{"error":"empty stream response from upstream"}`), + RetryableOnSameAccount: true, + } } // 如果收集到了图片 parts,需要合并到最终响应中 @@ -2812,7 +3330,22 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou // 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端) if logBody { - log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes)) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes)) + } + + // 检查错误透传规则 + if ptStatus, ptErrType, ptErrMsg, matched := applyErrorPassthroughRule( + c, account.Platform, upstreamStatus, body, + 0, "", "", + ); matched { + c.JSON(ptStatus, gin.H{ + "type": "error", + "error": gin.H{"type": ptErrType, "message": ptErrMsg}, + }) + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", upstreamStatus) + } + return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg) } var statusCode int @@ -2888,7 +3421,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) var firstTokenMs *int var last map[string]any @@ -2914,7 +3448,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2925,7 +3460,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -2952,7 +3487,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont } if ev.err != nil { if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err) } return nil, ev.err } @@ -3001,7 +3536,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if time.Since(lastRead) < streamInterval { continue } - log.Printf("Stream data interval timeout (antigravity claude non-stream)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity claude non-stream)") return nil, fmt.Errorf("stream data interval timeout") } } @@ -3010,10 +3545,14 @@ returnResponse: // 选择最后一个有效响应 finalResponse := pickGeminiCollectResult(last, lastWithParts) - // 处理空响应情况 + // 处理空响应情况 — 触发同账号重试 + failover 切换账号 if last == nil && lastWithParts == nil { - log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received") - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream") + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover") + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: []byte(`{"error":"empty stream response from upstream"}`), + RetryableOnSameAccount: true, + } } // 将收集的所有 parts 合并到最终响应中 @@ -3030,7 +3569,7 @@ returnResponse: // 转换 Gemini 响应为 Claude 格式 claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(geminiBody, originalModel) if err != nil { - log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody)) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody)) return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") } @@ -3068,7 +3607,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) // 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage { @@ -3100,7 +3640,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -3111,7 +3652,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) @@ -3128,10 +3669,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context intervalCh = intervalTicker.C } + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude") + // 仅发送一次错误事件,避免多次写入导致协议混乱 errorEventSent := false sendErrorEvent := func(reason string) { - if errorEventSent { + if errorEventSent || cw.Disconnected() { return } errorEventSent = true @@ -3139,21 +3682,29 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context flusher.Flush() } + // finishUsage 是获取 processor 最终 usage 的辅助函数 + finishUsage := func() *ClaudeUsage { + _, agUsage := processor.Finish() + return convertUsage(agUsage) + } + for { select { case ev, ok := <-events: if !ok { - // 发送结束事件 + // 上游完成,发送结束事件 finalEvents, agUsage := processor.Finish() if len(finalEvents) > 0 { - _, _ = c.Writer.Write(finalEvents) - flusher.Flush() + cw.Write(finalEvents) } - return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil + return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil } if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled { + return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil + } if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err } @@ -3161,25 +3712,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context return nil, fmt.Errorf("stream read error: %w", ev.err) } - line := ev.line // 处理 SSE 行,转换为 Claude 格式 - claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n")) - + claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n")) if len(claudeEvents) > 0 { if firstTokenMs == nil { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - - if _, writeErr := c.Writer.Write(claudeEvents); writeErr != nil { - finalEvents, agUsage := processor.Finish() - if len(finalEvents) > 0 { - _, _ = c.Writer.Write(finalEvents) - } - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr - } - flusher.Flush() + cw.Write(claudeEvents) } case <-intervalCh: @@ -3187,13 +3727,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if time.Since(lastRead) < streamInterval { continue } - log.Printf("Stream data interval timeout (antigravity)") - // 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout + if cw.Disconnected() { + logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity claude), returning collected usage") + return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)") sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } } - } // extractImageSize 从 Gemini 请求中提取 image_size 参数 @@ -3214,14 +3756,17 @@ func (s *AntigravityGatewayService) extractImageSize(body []byte) string { } // isImageGenerationModel 判断模型是否为图片生成模型 -// 支持的模型:gemini-3-pro-image, gemini-3-pro-image-preview, gemini-2.5-flash-image 等 +// 支持的模型:gemini-3.1-flash-image, gemini-3-pro-image, gemini-2.5-flash-image 等 func isImageGenerationModel(model string) bool { modelLower := strings.ToLower(model) // 移除 models/ 前缀 modelLower = strings.TrimPrefix(modelLower, "models/") // 精确匹配或前缀匹配 - return modelLower == "gemini-3-pro-image" || + return modelLower == "gemini-3.1-flash-image" || + modelLower == "gemini-3.1-flash-image-preview" || + strings.HasPrefix(modelLower, "gemini-3.1-flash-image-") || + modelLower == "gemini-3-pro-image" || modelLower == "gemini-3-pro-image-preview" || strings.HasPrefix(modelLower, "gemini-3-pro-image-") || modelLower == "gemini-2.5-flash-image" || @@ -3332,3 +3877,305 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { payload["contents"] = filtered return json.Marshal(payload) } + +// ForwardUpstream 使用 base_url + /v1/messages + 双 header 认证透传上游 Claude 请求 +func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { + startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) + + // 获取上游配置 + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if baseURL == "" || apiKey == "" { + return nil, fmt.Errorf("upstream account missing base_url or api_key") + } + baseURL = strings.TrimSuffix(baseURL, "/") + + // 解析请求获取模型信息 + var claudeReq antigravity.ClaudeRequest + if err := json.Unmarshal(body, &claudeReq); err != nil { + return nil, fmt.Errorf("parse claude request: %w", err) + } + if strings.TrimSpace(claudeReq.Model) == "" { + return nil, fmt.Errorf("missing model") + } + originalModel := claudeReq.Model + + // 构建上游请求 URL + upstreamURL := baseURL + "/v1/messages" + + // 创建请求 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create upstream request: %w", err) + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) // Claude API 兼容 + + // 透传 Claude 相关 headers + if v := c.GetHeader("anthropic-version"); v != "" { + req.Header.Set("anthropic-version", v) + } + if v := c.GetHeader("anthropic-beta"); v != "" { + req.Header.Set("anthropic-beta", v) + } + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + logger.LegacyPrintf("service.antigravity_gateway", "%s upstream request failed: %v", prefix, err) + return nil, fmt.Errorf("upstream request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // 处理错误响应 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + // 429 错误时标记账号限流 + if resp.StatusCode == http.StatusTooManyRequests { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", false) + } + + // 透传上游错误 + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(resp.StatusCode) + _, _ = c.Writer.Write(respBody) + + return &ForwardResult{ + Model: originalModel, + }, nil + } + + // 处理成功响应(流式/非流式) + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + + if claudeReq.Stream { + // 流式响应:透传 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + + streamRes := s.streamUpstreamResponse(c, resp, startTime) + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect + } else { + // 非流式响应:直接透传 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read upstream response: %w", err) + } + + // 提取 usage + usage = s.extractClaudeUsage(respBody) + + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(http.StatusOK) + _, _ = c.Writer.Write(respBody) + } + + // 构建计费结果 + duration := time.Since(startTime) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds()) + + return &ForwardResult{ + Model: originalModel, + Stream: claudeReq.Stream, + Duration: duration, + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + Usage: ClaudeUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + }, + }, nil +} + +// streamUpstreamResponse 透传上游 SSE 流并提取 Claude usage +func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult { + usage := &ClaudeUsage{} + var firstTokenMs *int + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.settingService.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func() { + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + streamInterval := time.Duration(0) + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream") + + for { + select { + case ev, ok := <-events: + if !ok { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()} + } + if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect} + } + logger.LegacyPrintf("service.antigravity_gateway", "Stream read error (antigravity upstream): %v", ev.err) + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} + } + + line := ev.line + + // 记录首 token 时间 + if firstTokenMs == nil && len(line) > 0 { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + // 尝试从 message_delta 或 message_stop 事件提取 usage + s.extractSSEUsage(line, usage) + + // 透传行 + cw.Fprintf("%s\n", line) + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if cw.Disconnected() { + logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity upstream), returning collected usage") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true} + } + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} + } + } +} + +// extractSSEUsage 从 SSE data 行中提取 Claude usage(用于流式透传场景) +func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUsage) { + if !strings.HasPrefix(line, "data: ") { + return + } + dataStr := strings.TrimPrefix(line, "data: ") + var event map[string]any + if json.Unmarshal([]byte(dataStr), &event) != nil { + return + } + u, ok := event["usage"].(map[string]any) + if !ok { + return + } + if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheCreationInputTokens = int(v) + } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + if cc, ok := u["cache_creation"].(map[string]any); ok { + if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok { + usage.CacheCreation5mTokens = int(v) + } + if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok { + usage.CacheCreation1hTokens = int(v) + } + } +} + +// extractClaudeUsage 从非流式 Claude 响应提取 usage +func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} + var resp map[string]any + if json.Unmarshal(body, &resp) != nil { + return usage + } + if u, ok := resp["usage"].(map[string]any); ok { + if v, ok := u["input_tokens"].(float64); ok { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok { + usage.CacheCreationInputTokens = int(v) + } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + if cc, ok := u["cache_creation"].(map[string]any); ok { + if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok { + usage.CacheCreation5mTokens = int(v) + } + if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok { + usage.CacheCreation1hTokens = int(v) + } + } + } + return usage +} diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 91cefc28..84b65adc 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -4,17 +4,43 @@ import ( "bytes" "context" "encoding/json" + "errors" + "fmt" "io" "net/http" "net/http/httptest" + "strings" "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) +// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter +type antigravityFailingWriter struct { + gin.ResponseWriter + failAfter int // 允许成功写入的次数,之后所有写入返回错误 + writes int +} + +func (w *antigravityFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + +// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService +func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService { + return &AntigravityGatewayService{ + settingService: &SettingService{cfg: cfg}, + } +} + func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) { req := &antigravity.ClaudeRequest{ Model: "claude-sonnet-4-5", @@ -108,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, return s.resp, s.err } +type antigravitySettingRepoStub struct{} + +func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + return "", ErrSettingNotFound +} + +func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *antigravitySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { gin.SetMode(gin.TestMode) writer := httptest.NewRecorder() @@ -134,8 +190,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { } svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: &httpUpstreamStub{resp: resp}, + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, } account := &Account{ @@ -337,8 +394,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } -// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling -// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true +// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies +// that ForwardGemini sets ForceCacheBilling=true for sticky session switch. func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) { gin.SetMode(gin.TestMode) writer := httptest.NewRecorder() @@ -391,3 +448,793 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling( require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode) require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } + +// TestAntigravityGatewayService_Forward_BillsWithMappedModel +// 验证:Antigravity Claude 转发返回的计费模型使用映射后的模型 +func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "hello"}, + }, + "max_tokens": 16, + "stream": true, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request = req + + upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"X-Request-Id": []string{"req-bill-1"}}, + Body: io.NopCloser(bytes.NewReader(upstreamBody)), + } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + const mappedModel = "gemini-3-pro-high" + account := &Account{ + ID: 5, + Name: "acc-forward-billing", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + "claude-sonnet-4-5": mappedModel, + }, + }, + } + + result, err := svc.Forward(context.Background(), c, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, mappedModel, result.Model) +} + +// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel +// 验证:Antigravity Gemini 转发返回的计费模型使用映射后的模型 +func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hello"}}}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + c.Request = req + + upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"X-Request-Id": []string{"req-bill-2"}}, + Body: io.NopCloser(bytes.NewReader(upstreamBody)), + } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + const mappedModel = "gemini-3-pro-high" + account := &Account{ + ID: 6, + Name: "acc-gemini-billing", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + "gemini-2.5-flash": mappedModel, + }, + }, + } + + result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, mappedModel, result.Model) +} + +// TestStreamUpstreamResponse_UsageAndFirstToken +// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间 +func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"usage":{"input_tokens":1,"output_tokens":2,"cache_read_input_tokens":3,"cache_creation_input_tokens":4}}`) + fmt.Fprintln(pw, `data: {"usage":{"output_tokens":5}}`) + }() + + start := time.Now().Add(-10 * time.Millisecond) + result := svc.streamUpstreamResponse(c, resp, start) + _ = pr.Close() + + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 1, result.usage.InputTokens) + // 第二次事件覆盖 output_tokens + require.Equal(t, 5, result.usage.OutputTokens) + require.Equal(t, 3, result.usage.CacheReadInputTokens) + require.Equal(t, 4, result.usage.CacheCreationInputTokens) + require.NotNil(t, result.firstTokenMs) + + // 确保有透传输出 + require.Contains(t, rec.Body.String(), "data:") +} + +// --- 流式 happy path 测试 --- + +// TestStreamUpstreamResponse_NormalComplete +// 验证:正常流式转发完成时,数据正确透传、usage 正确收集、clientDisconnect=false +func TestStreamUpstreamResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `event: message_start`) + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: content_block_delta`) + fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: message_delta`) + fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`) + fmt.Fprintln(pw, "") + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta") + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // 验证数据被透传到客户端 + body := rec.Body.String() + require.Contains(t, body, "event: message_start") + require.Contains(t, body, "content_block_delta") + require.Contains(t, body, "message_delta") +} + +// TestHandleGeminiStreamingResponse_NormalComplete +// 验证:正常 Gemini 流式转发,数据正确透传、usage 正确收集 +func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // 第一个 chunk(部分内容) + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`) + fmt.Fprintln(pw, "") + // 第二个 chunk(最终内容+完整 usage) + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + // Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2 + // → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2 + require.Equal(t, 8, result.usage.InputTokens) + require.Equal(t, 8, result.usage.OutputTokens) + require.Equal(t, 2, result.usage.CacheReadInputTokens) + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // 验证数据被透传到客户端 + body := rec.Body.String() + require.Contains(t, body, "Hello") + require.Contains(t, body, "world") + // 不应包含错误事件 + require.NotContains(t, body, "event: error") +} + +// TestHandleClaudeStreamingResponse_NormalComplete +// 验证:正常 Claude 流式转发(Gemini→Claude 转换),数据正确转换并输出 +func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // v1internal 包装格式:Gemini 数据嵌套在 "response" 字段下 + // ProcessLine 先尝试反序列化为 V1InternalResponse,裸格式会导致 Response.UsageMetadata 为空 + fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + // Gemini→Claude 转换的 usage:promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3 + require.Equal(t, 5, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // 验证输出是 Claude SSE 格式(processor 会转换) + body := rec.Body.String() + require.Contains(t, body, "event: message_start", "should contain Claude message_start event") + require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event") + // 不应包含错误事件 + require.NotContains(t, body, "event: error") +} + +// TestHandleGeminiStreamingResponse_ThoughtsTokenCount +// 验证:Gemini 流式转发时 thoughtsTokenCount 被计入 OutputTokens +func TestHandleGeminiStreamingResponse_ThoughtsTokenCount(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":30,"thoughtsTokenCount":80,"cachedContentTokenCount":10}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + // promptTokenCount=100, cachedContentTokenCount=10 → InputTokens=90 + require.Equal(t, 90, result.usage.InputTokens) + // candidatesTokenCount=30 + thoughtsTokenCount=80 → OutputTokens=110 + require.Equal(t, 110, result.usage.OutputTokens) + require.Equal(t, 10, result.usage.CacheReadInputTokens) +} + +// TestHandleClaudeStreamingResponse_ThoughtsTokenCount +// 验证:Gemini→Claude 流式转换时 thoughtsTokenCount 被计入 OutputTokens +func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":10,"thoughtsTokenCount":25}}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro") + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + // promptTokenCount=50 → InputTokens=50 + require.Equal(t, 50, result.usage.InputTokens) + // candidatesTokenCount=10 + thoughtsTokenCount=25 → OutputTokens=35 + require.Equal(t, 35, result.usage.OutputTokens) +} + +// --- 流式客户端断开检测测试 --- + +// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage +// 验证:客户端写入失败后,streamUpstreamResponse 继续读取上游以收集 usage +func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `event: message_start`) + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: message_delta`) + fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`) + fmt.Fprintln(pw, "") + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotNil(t, result.usage) + require.Equal(t, 20, result.usage.OutputTokens) +} + +// TestStreamUpstreamResponse_ContextCanceled +// 验证:context 取消时返回 usage 且标记 clientDisconnect +func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestStreamUpstreamResponse_Timeout +// 验证:上游超时时返回已收集的 usage +func TestStreamUpstreamResponse_Timeout(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pw.Close() + _ = pr.Close() + + require.NotNil(t, result) + require.False(t, result.clientDisconnect) +} + +// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect +// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect +func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`) + fmt.Fprintln(pw, "") + // 不关闭 pw → 等待超时 + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pw.Close() + _ = pr.Close() + + require.NotNil(t, result) + require.True(t, result.clientDisconnect) +} + +// TestHandleGeminiStreamingResponse_ClientDisconnect +// 验证:Gemini 流式转发中客户端断开后继续 drain 上游 +func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "write_failed") +} + +// TestHandleGeminiStreamingResponse_ContextCanceled +// 验证:context 取消时不注入错误事件 +func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestHandleClaudeStreamingResponse_ClientDisconnect +// 验证:Claude 流式转发中客户端断开后继续 drain 上游 +func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // v1internal 包装格式 + fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) +} + +// TestHandleClaudeStreamingResponse_ContextCanceled +// 验证:context 取消时不注入错误事件 +func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage +func TestExtractSSEUsage(t *testing.T) { + svc := &AntigravityGatewayService{} + tests := []struct { + name string + line string + expected ClaudeUsage + }{ + { + name: "message_delta with output_tokens", + line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`, + expected: ClaudeUsage{OutputTokens: 42}, + }, + { + name: "non-data line ignored", + line: `event: message_start`, + expected: ClaudeUsage{}, + }, + { + name: "top-level usage with all fields", + line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`, + expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + usage := &ClaudeUsage{} + svc.extractSSEUsage(tt.line, usage) + require.Equal(t, tt.expected, *usage) + }) + } +} + +// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测 +func TestAntigravityClientWriter(t *testing.T) { + t.Run("normal write succeeds", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(c.Writer, flusher, "test") + + ok := cw.Write([]byte("hello")) + require.True(t, ok) + require.False(t, cw.Disconnected()) + require.Contains(t, rec.Body.String(), "hello") + }) + + t.Run("write failure marks disconnected", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(fw, flusher, "test") + + ok := cw.Write([]byte("hello")) + require.False(t, ok) + require.True(t, cw.Disconnected()) + }) + + t.Run("subsequent writes are no-op", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(fw, flusher, "test") + + cw.Write([]byte("first")) + ok := cw.Fprintf("second %d", 2) + require.False(t, ok) + require.True(t, cw.Disconnected()) + }) +} + +// TestUnwrapV1InternalResponse 测试 unwrapV1InternalResponse 的各种输入场景 +func TestUnwrapV1InternalResponse(t *testing.T) { + svc := &AntigravityGatewayService{} + + // 构造 >50KB 的大型 JSON + largePadding := strings.Repeat("x", 50*1024) + largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding)) + largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding) + + tests := []struct { + name string + input []byte + expected string + wantErr bool + }{ + { + name: "正常 response 包装", + input: []byte(`{"response":{"id":"123","content":"hello"}}`), + expected: `{"id":"123","content":"hello"}`, + }, + { + name: "无 response 透传", + input: []byte(`{"id":"456"}`), + expected: `{"id":"456"}`, + }, + { + name: "空 JSON", + input: []byte(`{}`), + expected: `{}`, + }, + { + name: "response 为 null", + input: []byte(`{"response":null}`), + expected: `null`, + }, + { + name: "response 为基础类型 string", + input: []byte(`{"response":"hello"}`), + expected: `"hello"`, + }, + { + name: "非法 JSON", + input: []byte(`not json`), + expected: `not json`, + }, + { + name: "嵌套 response 只解一层", + input: []byte(`{"response":{"response":{"inner":true}}}`), + expected: `{"response":{"inner":true}}`, + }, + { + name: "大型 JSON >50KB", + input: largeInput, + expected: largeExpected, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := svc.unwrapV1InternalResponse(tt.input) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.expected, strings.TrimSpace(string(got))) + }) + } +} + +// --- unwrapV1InternalResponse benchmark 对照组 --- + +// unwrapV1InternalResponseOld 旧实现:Unmarshal+Marshal 双重开销(仅用于 benchmark 对照) +func unwrapV1InternalResponseOld(body []byte) ([]byte, error) { + var outer map[string]any + if err := json.Unmarshal(body, &outer); err != nil { + return nil, err + } + if resp, ok := outer["response"]; ok { + return json.Marshal(resp) + } + return body, nil +} + +func BenchmarkUnwrapV1Internal_Old_Small(b *testing.B) { + body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = unwrapV1InternalResponseOld(body) + } +} + +func BenchmarkUnwrapV1Internal_New_Small(b *testing.B) { + body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`) + svc := &AntigravityGatewayService{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = svc.unwrapV1InternalResponse(body) + } +} + +func BenchmarkUnwrapV1Internal_Old_Large(b *testing.B) { + body := generateLargeUnwrapJSON(10 * 1024) // ~10KB + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = unwrapV1InternalResponseOld(body) + } +} + +func BenchmarkUnwrapV1Internal_New_Large(b *testing.B) { + body := generateLargeUnwrapJSON(10 * 1024) // ~10KB + svc := &AntigravityGatewayService{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = svc.unwrapV1InternalResponse(body) + } +} + +// generateLargeUnwrapJSON 生成指定最小大小的包含 response 包装的 JSON +func generateLargeUnwrapJSON(minSize int) []byte { + parts := make([]map[string]string, 0) + current := 0 + for current < minSize { + text := fmt.Sprintf("这是第 %d 段内容,用于填充 JSON 到目标大小。", len(parts)+1) + parts = append(parts, map[string]string{"text": text}) + current += len(text) + 20 // 估算 JSON 编码开销 + } + inner := map[string]any{ + "candidates": []map[string]any{ + {"content": map[string]any{"parts": parts}}, + }, + "usageMetadata": map[string]any{ + "promptTokenCount": 100, + "candidatesTokenCount": 50, + }, + } + outer := map[string]any{"response": inner} + b, _ := json.Marshal(outer) + return b +} diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index f3621555..71939d26 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -76,6 +76,12 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { }, // 3. 默认映射中的透传(映射到自己) + { + name: "默认映射透传 - claude-sonnet-4-6", + requestedModel: "claude-sonnet-4-6", + accountMapping: nil, + expected: "claude-sonnet-4-6", + }, { name: "默认映射透传 - claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5", diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index fa8379ed..5f6691be 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -112,7 +112,10 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig } } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } // 交换 token tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier) @@ -167,7 +170,10 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken time.Sleep(backoff) } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } tokenResp, err := client.RefreshToken(ctx, refreshToken) if err == nil { now := time.Now() @@ -192,6 +198,46 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr) } +// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id) +func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64) (*AntigravityTokenInfo, error) { + var proxyURL string + if proxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) + if err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + + // 刷新 token + tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL) + if err != nil { + return nil, err + } + + // 获取用户信息(email) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } + userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken) + if err != nil { + fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err) + } else { + tokenInfo.Email = userInfo.Email + } + + // 获取 project_id(容错,失败不阻塞) + projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3) + if loadErr != nil { + fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr) + tokenInfo.ProjectIDMissing = true + } else { + tokenInfo.ProjectID = projectID + } + + return tokenInfo, nil +} + func isNonRetryableAntigravityOAuthError(err error) bool { msg := err.Error() nonRetryable := []string{ @@ -272,13 +318,25 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac time.Sleep(backoff) } - client := antigravity.NewClient(proxyURL) - loadResp, _, err := client.LoadCodeAssist(ctx, accessToken) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return "", fmt.Errorf("create antigravity client failed: %w", err) + } + loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken) if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { return loadResp.CloudAICompanionProject, nil } + if err == nil { + if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" { + return projectID, nil + } else if onboardErr != nil { + lastErr = onboardErr + continue + } + } + // 记录错误 if err != nil { lastErr = err @@ -292,6 +350,65 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr) } +func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) { + tierID := resolveDefaultTierID(loadRaw) + if tierID == "" { + return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier") + } + + projectID, err := client.OnboardUser(ctx, accessToken, tierID) + if err != nil { + return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err) + } + return projectID, nil +} + +func resolveDefaultTierID(loadRaw map[string]any) string { + if len(loadRaw) == 0 { + return "" + } + + rawTiers, ok := loadRaw["allowedTiers"] + if !ok { + return "" + } + + tiers, ok := rawTiers.([]any) + if !ok { + return "" + } + + for _, rawTier := range tiers { + tier, ok := rawTier.(map[string]any) + if !ok { + continue + } + if isDefault, _ := tier["isDefault"].(bool); !isDefault { + continue + } + if id, ok := tier["id"].(string); ok { + id = strings.TrimSpace(id) + if id != "" { + return id + } + } + } + + return "" +} + +// FillProjectID 仅获取 project_id,不刷新 OAuth token +func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) { + var proxyURL string + if account.ProxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID) + if err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3) +} + // BuildAccountCredentials 构建账户凭证 func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any { creds := map[string]any{ diff --git a/backend/internal/service/antigravity_oauth_service_test.go b/backend/internal/service/antigravity_oauth_service_test.go new file mode 100644 index 00000000..1d2d8235 --- /dev/null +++ b/backend/internal/service/antigravity_oauth_service_test.go @@ -0,0 +1,82 @@ +package service + +import ( + "testing" +) + +func TestResolveDefaultTierID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + loadRaw map[string]any + want string + }{ + { + name: "nil loadRaw", + loadRaw: nil, + want: "", + }, + { + name: "missing allowedTiers", + loadRaw: map[string]any{ + "paidTier": map[string]any{"id": "g1-pro-tier"}, + }, + want: "", + }, + { + name: "empty allowedTiers", + loadRaw: map[string]any{"allowedTiers": []any{}}, + want: "", + }, + { + name: "tier missing id field", + loadRaw: map[string]any{ + "allowedTiers": []any{ + map[string]any{"isDefault": true}, + }, + }, + want: "", + }, + { + name: "allowedTiers but no default", + loadRaw: map[string]any{ + "allowedTiers": []any{ + map[string]any{"id": "free-tier", "isDefault": false}, + map[string]any{"id": "standard-tier", "isDefault": false}, + }, + }, + want: "", + }, + { + name: "default tier found", + loadRaw: map[string]any{ + "allowedTiers": []any{ + map[string]any{"id": "free-tier", "isDefault": true}, + map[string]any{"id": "standard-tier", "isDefault": false}, + }, + }, + want: "free-tier", + }, + { + name: "default tier id with spaces", + loadRaw: map[string]any{ + "allowedTiers": []any{ + map[string]any{"id": " standard-tier ", "isDefault": true}, + }, + }, + want: "standard-tier", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := resolveDefaultTierID(tc.loadRaw) + if got != tc.want { + t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want) + } + }) + } +} diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go index 07eb563d..e950ec1d 100644 --- a/backend/internal/service/antigravity_quota_fetcher.go +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -2,6 +2,7 @@ package service import ( "context" + "fmt" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" @@ -31,7 +32,10 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou accessToken := account.GetCredential("access_token") projectID := account.GetCredential("project_id") - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } // 调用 API 获取配额 modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) diff --git a/backend/internal/service/antigravity_quota_scope.go b/backend/internal/service/antigravity_quota_scope.go index 43ac6c2f..e181e7f8 100644 --- a/backend/internal/service/antigravity_quota_scope.go +++ b/backend/internal/service/antigravity_quota_scope.go @@ -2,63 +2,23 @@ package service import ( "context" - "slices" "strings" "time" ) -const antigravityQuotaScopesKey = "antigravity_quota_scopes" - -// AntigravityQuotaScope 表示 Antigravity 的配额域 -type AntigravityQuotaScope string - -const ( - AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude" - AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text" - AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image" -) - -// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中 -func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool { - if len(supportedScopes) == 0 { - // 未配置时默认全部支持 - return true - } - supported := slices.Contains(supportedScopes, string(scope)) - return supported -} - -// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本) -func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { - return resolveAntigravityQuotaScope(requestedModel) -} - -// resolveAntigravityQuotaScope 根据模型名称解析配额域 -func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { - model := normalizeAntigravityModelName(requestedModel) - if model == "" { - return "", false - } - switch { - case strings.HasPrefix(model, "claude-"): - return AntigravityQuotaScopeClaude, true - case strings.HasPrefix(model, "gemini-"): - if isImageGenerationModel(model) { - return AntigravityQuotaScopeGeminiImage, true - } - return AntigravityQuotaScopeGeminiText, true - default: - return "", false - } -} - func normalizeAntigravityModelName(model string) string { normalized := strings.ToLower(strings.TrimSpace(model)) normalized = strings.TrimPrefix(normalized, "models/") return normalized } -// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。 +// resolveAntigravityModelKey 根据请求的模型名解析限流 key +// 返回空字符串表示无法解析 +func resolveAntigravityModelKey(requestedModel string) string { + return normalizeAntigravityModelName(requestedModel) +} + +// IsSchedulableForModel 结合模型级限流判断是否可调度。 // 保持旧签名以兼容既有调用方;默认使用 context.Background()。 func (a *Account) IsSchedulableForModel(requestedModel string) bool { return a.IsSchedulableForModelWithContext(context.Background(), requestedModel) @@ -74,107 +34,20 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste if a.isModelRateLimitedWithContext(ctx, requestedModel) { return false } - if a.Platform != PlatformAntigravity { - return true - } - scope, ok := resolveAntigravityQuotaScope(requestedModel) - if !ok { - return true - } - resetAt := a.antigravityQuotaScopeResetAt(scope) - if resetAt == nil { - return true - } - now := time.Now() - return !now.Before(*resetAt) + return true } -func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time { - if a == nil || a.Extra == nil || scope == "" { - return nil - } - rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any) - if !ok { - return nil - } - rawScope, ok := rawScopes[string(scope)].(map[string]any) - if !ok { - return nil - } - resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string) - if !ok || strings.TrimSpace(resetAtRaw) == "" { - return nil - } - resetAt, err := time.Parse(time.RFC3339, resetAtRaw) - if err != nil { - return nil - } - return &resetAt -} - -var antigravityAllScopes = []AntigravityQuotaScope{ - AntigravityQuotaScopeClaude, - AntigravityQuotaScopeGeminiText, - AntigravityQuotaScopeGeminiImage, -} - -func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 { - if a == nil || a.Platform != PlatformAntigravity { - return nil - } - now := time.Now() - result := make(map[string]int64) - for _, scope := range antigravityAllScopes { - resetAt := a.antigravityQuotaScopeResetAt(scope) - if resetAt != nil && now.Before(*resetAt) { - remainingSec := int64(time.Until(*resetAt).Seconds()) - if remainingSec > 0 { - result[string(scope)] = remainingSec - } - } - } - if len(result) == 0 { - return nil - } - return result -} - -// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间 -// 返回 0 表示未限流或已过期 -func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration { - if a == nil || a.Platform != PlatformAntigravity { - return 0 - } - scope, ok := resolveAntigravityQuotaScope(requestedModel) - if !ok { - return 0 - } - resetAt := a.antigravityQuotaScopeResetAt(scope) - if resetAt == nil { - return 0 - } - if remaining := time.Until(*resetAt); remaining > 0 { - return remaining - } - return 0 -} - -// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值) +// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流) // 返回 0 表示未限流或已过期 func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration { return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel) } -// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值) +// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型级限流) // 返回 0 表示未限流或已过期 func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration { if a == nil { return 0 } - modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel) - scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel) - if modelRemaining > scopeRemaining { - return modelRemaining - } - return scopeRemaining + return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel) } diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 20936356..dd8dd83f 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -15,6 +15,12 @@ import ( "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ HTTPUpstream = (*stubAntigravityUpstream)(nil) +var _ HTTPUpstream = (*recordingOKUpstream)(nil) +var _ AccountRepository = (*stubAntigravityAccountRepo)(nil) +var _ SchedulerCache = (*stubSchedulerCache)(nil) + type stubAntigravityUpstream struct { firstBase string secondBase string @@ -59,12 +65,6 @@ func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, return s.Do(req, proxyURL, accountID, accountConcurrency) } -type scopeLimitCall struct { - accountID int64 - scope AntigravityQuotaScope - resetAt time.Time -} - type rateLimitCall struct { accountID int64 resetAt time.Time @@ -78,16 +78,10 @@ type modelRateLimitCall struct { type stubAntigravityAccountRepo struct { AccountRepository - scopeCalls []scopeLimitCall rateCalls []rateLimitCall modelRateLimitCalls []modelRateLimitCall } -func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt}) - return nil -} - func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt}) return nil @@ -98,7 +92,9 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i return nil } -func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { +func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) { + t.Setenv(antigravityForwardBaseURLEnv, "") + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) oldAvailability := antigravity.DefaultURLAvailability defer func() { @@ -131,10 +127,9 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - quotaScope: AntigravityQuotaScopeClaude, httpUpstream: upstream, requestedModel: "claude-sonnet-4-5", - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { handleErrorCalled = true return nil }, @@ -144,32 +139,16 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { require.NotNil(t, result) require.NotNil(t, result.resp) defer func() { _ = result.resp.Body.Close() }() - require.Equal(t, http.StatusOK, result.resp.StatusCode) - require.False(t, handleErrorCalled) - require.Len(t, upstream.calls, 2) - require.True(t, strings.HasPrefix(upstream.calls[0], base1)) - require.True(t, strings.HasPrefix(upstream.calls[1], base2)) + require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode) + require.True(t, handleErrorCalled) + require.Len(t, upstream.calls, antigravityMaxRetries) + for _, callURL := range upstream.calls { + require.True(t, strings.HasPrefix(callURL, base1)) + } available := antigravity.DefaultURLAvailability.GetAvailableURLs() require.NotEmpty(t, available) - require.Equal(t, base2, available[0]) -} - -func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) { - // 分区限流始终开启,不再支持通过环境变量关闭 - repo := &stubAntigravityAccountRepo{} - svc := &AntigravityGatewayService{accountRepo: repo} - account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity} - - body := buildGeminiRateLimitBody("3s") - svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) - - require.Len(t, repo.scopeCalls, 1) - require.Empty(t, repo.rateCalls) - call := repo.scopeCalls[0] - require.Equal(t, account.ID, call.accountID) - require.Equal(t, AntigravityQuotaScopeClaude, call.scope) - require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second) + require.Equal(t, base1, available[0]) } // TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景 @@ -189,7 +168,7 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) { } }`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false) // 应该触发模型限流 require.NotNil(t, result) @@ -200,31 +179,48 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) { require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } -// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流) +// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走模型级限流兜底) func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) { repo := &stubAntigravityAccountRepo{} svc := &AntigravityGatewayService{accountRepo: repo} account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity} - // 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ scope 限流 + // 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ 走模型级限流兜底 body := buildGeminiRateLimitBody("5s") - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false) - // 不应该触发模型限流,应该走 scope 限流 + // handleModelRateLimit 不会处理(因为没有 RATE_LIMIT_EXCEEDED), + // 但 429 兜底逻辑会使用 requestedModel 设置模型级限流 require.Nil(t, result) - require.Empty(t, repo.modelRateLimitCalls) - require.Len(t, repo.scopeCalls, 1) - require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope) + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } -// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景 -func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) { +// TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey 测试 429 非模型限流场景 +// 验证:requestedModel 会被映射到 Antigravity 最终模型(例如 claude-opus-4-6 -> claude-opus-4-6-thinking) +func TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 20, Name: "acc-20", Platform: PlatformAntigravity} + + body := buildGeminiRateLimitBody("5s") + + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-opus-4-6", 0, "", false) + + require.Nil(t, result) + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-opus-4-6-thinking", repo.modelRateLimitCalls[0].modelKey) +} + +// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景 +// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号 +func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) { repo := &stubAntigravityAccountRepo{} svc := &AntigravityGatewayService{accountRepo: repo} account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity} - // 503 + MODEL_CAPACITY_EXHAUSTED → 模型限流 + // 503 + MODEL_CAPACITY_EXHAUSTED → 等待重试,不切换账号 body := []byte(`{ "error": { "status": "UNAVAILABLE", @@ -235,15 +231,15 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) { } }`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) - // 应该触发模型限流 + // MODEL_CAPACITY_EXHAUSTED 应该标记为已处理,不切换账号,不设置模型限流 + // 实际重试由 handleSmartRetry 处理 require.NotNil(t, result) require.True(t, result.Handled) - require.NotNil(t, result.SwitchError) - require.Equal(t, "gemini-3-pro-high", result.SwitchError.RateLimitedModel) - require.Len(t, repo.modelRateLimitCalls, 1) - require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey) + require.False(t, result.ShouldRetry, "MODEL_CAPACITY_EXHAUSTED should not trigger retry from handleModelRateLimit path") + require.Nil(t, result.SwitchError, "MODEL_CAPACITY_EXHAUSTED should not trigger account switch") + require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit") } // TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理) @@ -263,12 +259,11 @@ func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) { } }`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) // 503 非模型限流不应该做任何处理 require.Nil(t, result) require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit") - require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit") require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit") } @@ -281,12 +276,11 @@ func TestHandleUpstreamError_503_EmptyBody(t *testing.T) { // 503 + 空响应体 → 不做任何处理 body := []byte(`{}`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) // 503 空响应不应该做任何处理 require.Nil(t, result) require.Empty(t, repo.modelRateLimitCalls) - require.Empty(t, repo.scopeCalls) require.Empty(t, repo.rateCalls) } @@ -307,15 +301,7 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) { require.False(t, account.IsSchedulableForModel("gemini-3-flash")) account.RateLimitResetAt = nil - account.Extra = map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future.Format(time.RFC3339), - }, - }, - } - - require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5")) + require.True(t, account.IsSchedulableForModel("claude-sonnet-4-5")) require.True(t, account.IsSchedulableForModel("gemini-3-flash")) } @@ -341,11 +327,12 @@ func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) { func TestParseAntigravitySmartRetryInfo(t *testing.T) { tests := []struct { - name string - body string - expectedDelay time.Duration - expectedModel string - expectedNil bool + name string + body string + expectedDelay time.Duration + expectedModel string + expectedNil bool + expectedIsModelCapacityExhausted bool }{ { name: "valid complete response with RATE_LIMIT_EXCEEDED", @@ -408,8 +395,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) { "message": "No capacity available for model gemini-3-pro-high on the server" } }`, - expectedDelay: 39 * time.Second, - expectedModel: "gemini-3-pro-high", + expectedDelay: 39 * time.Second, + expectedModel: "gemini-3-pro-high", + expectedIsModelCapacityExhausted: true, }, { name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil", @@ -520,6 +508,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) { if result.ModelName != tt.expectedModel { t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel) } + if result.IsModelCapacityExhausted != tt.expectedIsModelCapacityExhausted { + t.Errorf("IsModelCapacityExhausted = %v, want %v", result.IsModelCapacityExhausted, tt.expectedIsModelCapacityExhausted) + } }) } } @@ -531,13 +522,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { apiKeyAccount := &Account{Type: AccountTypeAPIKey} tests := []struct { - name string - account *Account - body string - expectedShouldRetry bool - expectedShouldRateLimit bool - minWait time.Duration - modelName string + name string + account *Account + body string + expectedShouldRetry bool + expectedShouldRateLimit bool + expectedIsModelCapacityExhausted bool + minWait time.Duration + modelName string }{ { name: "OAuth account with short delay (< 7s) - smart retry", @@ -635,6 +627,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { }`, expectedShouldRetry: false, expectedShouldRateLimit: true, + minWait: 7 * time.Second, modelName: "gemini-pro", }, { @@ -650,12 +643,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { ] } }`, - expectedShouldRetry: false, - expectedShouldRateLimit: true, - modelName: "gemini-3-pro-high", + expectedShouldRetry: true, + expectedShouldRateLimit: false, + expectedIsModelCapacityExhausted: true, + minWait: 1 * time.Second, + modelName: "gemini-3-pro-high", }, { - name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit", + name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use fixed wait", account: oauthAccount, body: `{ "error": { @@ -667,9 +662,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { "message": "No capacity available for model gemini-2.5-flash on the server" } }`, - expectedShouldRetry: false, - expectedShouldRateLimit: true, - modelName: "gemini-2.5-flash", + expectedShouldRetry: true, + expectedShouldRateLimit: false, + expectedIsModelCapacityExhausted: true, + minWait: 1 * time.Second, + modelName: "gemini-2.5-flash", }, { name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit", @@ -686,24 +683,33 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { }`, expectedShouldRetry: false, expectedShouldRateLimit: true, + minWait: 30 * time.Second, modelName: "claude-sonnet-4-5", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - shouldRetry, shouldRateLimit, wait, model := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body)) + shouldRetry, shouldRateLimit, wait, model, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body)) if shouldRetry != tt.expectedShouldRetry { t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry) } if shouldRateLimit != tt.expectedShouldRateLimit { t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit) } + if isModelCapacityExhausted != tt.expectedIsModelCapacityExhausted { + t.Errorf("isModelCapacityExhausted = %v, want %v", isModelCapacityExhausted, tt.expectedIsModelCapacityExhausted) + } if shouldRetry { if wait < tt.minWait { t.Errorf("wait = %v, want >= %v", wait, tt.minWait) } } + if shouldRateLimit && tt.minWait > 0 { + if wait < tt.minWait { + t.Errorf("rate limit wait = %v, want >= %v", wait, tt.minWait) + } + } if (shouldRetry || shouldRateLimit) && model != tt.modelName { t.Errorf("modelName = %q, want %q", model, tt.modelName) } @@ -803,7 +809,7 @@ func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) { require.NotEqual(t, "claude_sonnet", call.modelKey, "should NOT be scope") } -func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) { +func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) { upstream := &recordingOKUpstream{} account := &Account{ ID: 1, @@ -815,19 +821,15 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi Extra: map[string]any{ modelRateLimitsKey: map[string]any{ "claude-sonnet-4-5": map[string]any{ - // RFC3339 here is second-precision; keep it safely in the future. "rate_limit_reset_at": time.Now().Add(2 * time.Second).Format(time.RFC3339), }, }, }, } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) - defer cancel() - svc := &AntigravityGatewayService{} result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ - ctx: ctx, + ctx: context.Background(), prefix: "[test]", account: account, accessToken: "token", @@ -836,17 +838,21 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi requestedModel: "claude-sonnet-4-5", httpUpstream: upstream, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, }) - require.ErrorIs(t, err, context.DeadlineExceeded) require.Nil(t, result) - require.Equal(t, 0, upstream.calls, "should not call upstream while waiting on pre-check") + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel) + require.True(t, switchErr.IsStickySession) + require.Equal(t, 0, upstream.calls, "should not call upstream when switching on pre-check") } -func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) { +func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) { upstream := &recordingOKUpstream{} account := &Account{ ID: 2, @@ -875,7 +881,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t requestedModel: "claude-sonnet-4-5", httpUpstream: upstream, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, }) @@ -946,6 +952,22 @@ func TestIsAntigravityAccountSwitchError(t *testing.T) { } } +func TestResolveAntigravityForwardBaseURL_DefaultDaily(t *testing.T) { + t.Setenv(antigravityForwardBaseURLEnv, "") + + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) + defer func() { + antigravity.BaseURLs = oldBaseURLs + }() + + prodURL := "https://prod.test" + dailyURL := "https://daily.test" + antigravity.BaseURLs = []string{dailyURL, prodURL} + + resolved := resolveAntigravityForwardBaseURL() + require.Equal(t, dailyURL, resolved) +} + func TestAntigravityAccountSwitchError_Error(t *testing.T) { err := &AntigravityAccountSwitchError{ OriginalAccountID: 789, diff --git a/backend/internal/service/antigravity_single_account_retry_test.go b/backend/internal/service/antigravity_single_account_retry_test.go new file mode 100644 index 00000000..8b01cc31 --- /dev/null +++ b/backend/internal/service/antigravity_single_account_retry_test.go @@ -0,0 +1,904 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// 辅助函数:构造带 SingleAccountRetry 标记的 context +// --------------------------------------------------------------------------- + +func ctxWithSingleAccountRetry() context.Context { + return context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true) +} + +// --------------------------------------------------------------------------- +// 1. isSingleAccountRetry 测试 +// --------------------------------------------------------------------------- + +func TestIsSingleAccountRetry_True(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true) + require.True(t, isSingleAccountRetry(ctx)) +} + +func TestIsSingleAccountRetry_False_NoValue(t *testing.T) { + require.False(t, isSingleAccountRetry(context.Background())) +} + +func TestIsSingleAccountRetry_False_ExplicitFalse(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, false) + require.False(t, isSingleAccountRetry(ctx)) +} + +func TestIsSingleAccountRetry_False_WrongType(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, "true") + require.False(t, isSingleAccountRetry(ctx)) +} + +// --------------------------------------------------------------------------- +// 2. 常量验证 +// --------------------------------------------------------------------------- + +func TestSingleAccountRetryConstants(t *testing.T) { + require.Equal(t, 3, antigravitySingleAccountSmartRetryMaxAttempts, + "单账号原地重试最多 3 次") + require.Equal(t, 15*time.Second, antigravitySingleAccountSmartRetryMaxWait, + "单次最大等待 15s") + require.Equal(t, 30*time.Second, antigravitySingleAccountSmartRetryTotalMaxWait, + "总累计等待不超过 30s") +} + +// --------------------------------------------------------------------------- +// 3. handleSmartRetry + 503 + SingleAccountRetry → 走 handleSingleAccountRetryInPlace +// (而非设模型限流 + 切换账号) +// --------------------------------------------------------------------------- + +// TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace +// 核心场景:503 + retryDelay >= 7s + SingleAccountRetry 标记 +// → 不设模型限流、不切换账号,改为原地重试 +func TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace(t *testing.T) { + // 原地重试成功 + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 1, + Name: "acc-single", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + // 503 + 39s >= 7s 阈值 + MODEL_CAPACITY_EXHAUSTED + respBody := []byte(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"} + ], + "message": "No capacity available for model gemini-3-pro-high on the server" + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), // 关键:设置单账号标记 + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 关键断言:返回 resp(原地重试成功),而非 switchError(切换账号) + require.NotNil(t, result.resp, "should return successful response from in-place retry") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Nil(t, result.switchError, "should NOT return switchError in single account mode") + require.Nil(t, result.err) + + // 验证未设模型限流(单账号模式不应设限流) + require.Len(t, repo.modelRateLimitCalls, 0, + "should NOT set model rate limit in single account retry mode") + + // 验证确实调用了 upstream(原地重试) + require.GreaterOrEqual(t, len(upstream.calls), 1, "should have made at least one retry call") +} + +// TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches +// 对照组:503 + retryDelay >= 7s + 无 SingleAccountRetry 标记 +// → 照常设模型限流 + 切换账号 +func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 2, + Name: "acc-multi", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 503 + 39s >= 7s 阈值(使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED, + // 因为 MODEL_CAPACITY_EXHAUSTED 走独立的重试路径,不触发 shouldRateLimitModel) + respBody := []byte(`{ + "error": { + "code": 503, + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), // 关键:无单账号标记 + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 对照:多账号模式返回 switchError + require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503") + require.Nil(t, result.resp, "should not return resp when switchError is set") + + // 对照:多账号模式应设模型限流 + require.Len(t, repo.modelRateLimitCalls, 1, + "multi-account mode SHOULD set model rate limit") +} + +// TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches +// 边界情况:429(非 503)+ SingleAccountRetry 标记 +// → 单账号原地重试仅针对 503,429 依然走切换账号逻辑 +func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 3, + Name: "acc-429", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 429 + 15s >= 7s 阈值 + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, // 429,不是 503 + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), // 有单账号标记 + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 429 即使有单账号标记,也应走切换账号 + require.NotNil(t, result.switchError, "429 should still return switchError even with SingleAccountRetry") + require.Len(t, repo.modelRateLimitCalls, 1, + "429 should still set model rate limit even with SingleAccountRetry") +} + +// --------------------------------------------------------------------------- +// 4. handleSmartRetry + 503 + 短延迟 + SingleAccountRetry → 智能重试耗尽后不设限流 +// --------------------------------------------------------------------------- + +// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit +// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503,不设限流 +func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) { + // 智能重试也返回 503 + failRespBody := `{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 4, + Name: "acc-short-503", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 0.1s < 7s 阈值 + respBody := []byte(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 关键断言:单账号 503 模式下,智能重试耗尽后直接返回 503 响应,不切换 + require.NotNil(t, result.resp, "should return 503 response directly for single account mode") + require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode) + require.Nil(t, result.switchError, "should NOT switch account in single account mode") + + // 关键断言:不设模型限流 + require.Len(t, repo.modelRateLimitCalls, 0, + "should NOT set model rate limit for 503 in single account mode") +} + +// TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit +// 对照组:503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流 +// 使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,因为后者走独立的 60 次重试路径 +func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) { + failRespBody := `{ + "error": { + "code": 503, + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 5, + Name: "acc-multi-503", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "code": 503, + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), // 无单账号标记 + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 对照:多账号模式应返回 switchError + require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503") + // 对照:多账号模式应设模型限流 + require.Len(t, repo.modelRateLimitCalls, 1, + "multi-account mode should set model rate limit") +} + +// --------------------------------------------------------------------------- +// 5. handleSingleAccountRetryInPlace 直接测试 +// --------------------------------------------------------------------------- + +// TestHandleSingleAccountRetryInPlace_Success 原地重试成功 +func TestHandleSingleAccountRetryInPlace_Success(t *testing.T) { + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + + account := &Account{ + ID: 10, + Name: "acc-inplace-ok", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro") + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp, "should return successful response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Nil(t, result.switchError, "should not switch account on success") + require.Nil(t, result.err) +} + +// TestHandleSingleAccountRetryInPlace_AllRetriesFail 所有重试都失败,返回 503(不设限流) +func TestHandleSingleAccountRetryInPlace_AllRetriesFail(t *testing.T) { + // 构造 3 个 503 响应(对应 3 次原地重试) + var responses []*http.Response + var errors []error + for i := 0; i < antigravitySingleAccountSmartRetryMaxAttempts; i++ { + responses = append(responses, &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`)), + }) + errors = append(errors, nil) + } + upstream := &mockSmartRetryUpstream{ + responses: responses, + errors: errors, + } + + account := &Account{ + ID: 11, + Name: "acc-inplace-fail", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + origBody := []byte(`{"error":{"code":503,"status":"UNAVAILABLE"}}`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{"X-Test": {"original"}}, + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSingleAccountRetryInPlace(params, resp, origBody, "https://ag-1.test", 1*time.Second, "gemini-3-pro") + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 关键:返回 503 resp,不返回 switchError + require.NotNil(t, result.resp, "should return 503 response directly") + require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode) + require.Nil(t, result.switchError, "should NOT return switchError - let Handler handle it") + require.Nil(t, result.err) + + // 验证确实重试了指定次数 + require.Len(t, upstream.calls, antigravitySingleAccountSmartRetryMaxAttempts, + "should have made exactly maxAttempts retry calls") +} + +// TestHandleSingleAccountRetryInPlace_WaitDurationClamped 等待时间被限制在 [min, max] 范围 +func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) { + // 用短延迟的成功响应,只验证不 panic + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + + account := &Account{ + ID: 12, + Name: "acc-clamp", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + + // 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait + result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro") + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp) + require.Equal(t, http.StatusOK, result.resp.StatusCode) +} + +// TestHandleSingleAccountRetryInPlace_ContextCanceled context 取消时立即返回 +func TestHandleSingleAccountRetryInPlace_ContextCanceled(t *testing.T) { + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{nil}, + errors: []error{nil}, + } + + account := &Account{ + ID: 13, + Name: "acc-cancel", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + } + + ctx, cancel := context.WithCancel(context.Background()) + ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true) + cancel() // 立即取消 + + params := antigravityRetryLoopParams{ + ctx: ctx, + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro") + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.Error(t, result.err, "should return context error") + // 不应调用 upstream(因为在等待阶段就被取消了) + require.Len(t, upstream.calls, 0, "should not call upstream when context is canceled") +} + +// TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry 网络错误时继续重试 +func TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry(t *testing.T) { + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + // 第1次网络错误(nil resp),第2次成功 + responses: []*http.Response{nil, successResp}, + errors: []error{nil, nil}, + } + + account := &Account{ + ID: 14, + Name: "acc-net-retry", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro") + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp, "should return successful response after network error recovery") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Len(t, upstream.calls, 2, "first call fails (network error), second succeeds") +} + +// --------------------------------------------------------------------------- +// 6. antigravityRetryLoop 预检查:单账号模式跳过限流 +// --------------------------------------------------------------------------- + +// TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit +// 预检查中,如果有 SingleAccountRetry 标记,即使账号已限流也跳过直接发请求 +func TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit(t *testing.T) { + // 创建一个已设模型限流的账号 + upstream := &recordingOKUpstream{} + account := &Account{ + ID: 20, + Name: "acc-rate-limited", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339), + }, + }, + }, + } + + svc := &AntigravityGatewayService{} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + requestedModel: "claude-sonnet-4-5", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.NoError(t, err, "should not return error") + require.NotNil(t, result, "should return result") + require.NotNil(t, result.resp, "should have response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + // 关键:尽管限流了,有 SingleAccountRetry 标记时仍然到达了 upstream + require.Equal(t, 1, upstream.calls, "should have reached upstream despite rate limit") +} + +// TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit +// 对照组:无 SingleAccountRetry + 已限流 → 预检查返回 switchError +func TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit(t *testing.T) { + upstream := &recordingOKUpstream{} + account := &Account{ + ID: 21, + Name: "acc-rate-limited-multi", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339), + }, + }, + }, + } + + svc := &AntigravityGatewayService{} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: context.Background(), // 无单账号标记 + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + requestedModel: "claude-sonnet-4-5", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.Nil(t, result, "should not return result on rate limit switch") + require.NotNil(t, err, "should return error") + + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr, "should return AntigravityAccountSwitchError") + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel) + + // upstream 不应被调用(预检查就短路了) + require.Equal(t, 0, upstream.calls, "upstream should NOT be called when pre-check blocks") +} + +// --------------------------------------------------------------------------- +// 7. 端到端集成场景测试 +// --------------------------------------------------------------------------- + +// TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E +// 端到端场景:503 + 单账号 + 原地重试第2次成功 +func TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E(t *testing.T) { + // 第1次原地重试仍返回 503,第2次成功 + fail503Body := `{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + resp503 := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(fail503Body)), + } + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{resp503, successResp}, + errors: []error{nil, nil}, + } + + account := &Account{ + ID: 30, + Name: "acc-e2e", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro") + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp, "should return successful response after 2nd attempt") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Nil(t, result.switchError) + require.Len(t, upstream.calls, 2, "first 503, second OK") +} + +// TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E +// 通过 antigravityRetryLoop → handleSmartRetry → handleSingleAccountRetryInPlace 完整链路 +func TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E(t *testing.T) { + // 初始请求返回 503 + 长延迟 + initial503Body := []byte(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "10s"} + ], + "message": "No capacity available" + } + }`) + initial503Resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(initial503Body)), + } + + // 原地重试成功 + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + + upstream := &mockSmartRetryUpstream{ + // 第1次调用(retryLoop 主循环)返回 503 + // 第2次调用(handleSingleAccountRetryInPlace 原地重试)返回 200 + responses: []*http.Response{initial503Resp, successResp}, + errors: []error{nil, nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 31, + Name: "acc-e2e-loop", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + svc := &AntigravityGatewayService{} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.NoError(t, err, "should not return error on successful retry") + require.NotNil(t, result, "should return result") + require.NotNil(t, result.resp, "should return response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + + // 验证未设模型限流 + require.Len(t, repo.modelRateLimitCalls, 0, + "should NOT set model rate limit in single account retry mode") +} diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go index 623dfec5..432c80e5 100644 --- a/backend/internal/service/antigravity_smart_retry_test.go +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -13,6 +13,23 @@ import ( "github.com/stretchr/testify/require" ) +// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock +// 仅关注 DeleteSessionAccountID 的调用记录 +type stubSmartRetryCache struct { + GatewayCache // 嵌入接口,未实现的方法 panic(确保只调用预期方法) + deleteCalls []deleteSessionCall +} + +type deleteSessionCall struct { + groupID int64 + sessionHash string +} + +func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID int64, sessionHash string) error { + c.deleteCalls = append(c.deleteCalls, deleteSessionCall{groupID: groupID, sessionHash: sessionHash}) + return nil +} + // mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream type mockSmartRetryUpstream struct { responses []*http.Response @@ -58,7 +75,7 @@ func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) { accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -110,7 +127,7 @@ func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) { body: []byte(`{"input":"test"}`), accountRepo: repo, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -177,7 +194,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) { action: "generateContent", body: []byte(`{"input":"test"}`), httpUpstream: upstream, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -198,7 +215,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) { // TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError 测试智能重试失败后返回 switchError func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *testing.T) { - // 智能重试后仍然返回 429(需要提供 3 个响应,因为智能重试最多 3 次) + // 智能重试后仍然返回 429(需要提供 1 个响应,因为智能重试最多 1 次) failRespBody := `{ "error": { "status": "RESOURCE_EXHAUSTED", @@ -213,19 +230,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test Header: http.Header{}, Body: io.NopCloser(strings.NewReader(failRespBody)), } - failResp2 := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Header: http.Header{}, - Body: io.NopCloser(strings.NewReader(failRespBody)), - } - failResp3 := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Header: http.Header{}, - Body: io.NopCloser(strings.NewReader(failRespBody)), - } upstream := &mockSmartRetryUpstream{ - responses: []*http.Response{failResp1, failResp2, failResp3}, - errors: []error{nil, nil, nil}, + responses: []*http.Response{failResp1}, + errors: []error{nil}, } repo := &stubAntigravityAccountRepo{} @@ -236,7 +243,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test Platform: PlatformAntigravity, } - // 3s < 7s 阈值,应该触发智能重试(最多 3 次) + // 3s < 7s 阈值,应该触发智能重试(最多 1 次) respBody := []byte(`{ "error": { "status": "RESOURCE_EXHAUSTED", @@ -262,7 +269,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test httpUpstream: upstream, accountRepo: repo, isStickySession: false, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -284,11 +291,12 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test // 验证模型限流已设置 require.Len(t, repo.modelRateLimitCalls, 1) require.Equal(t, "gemini-3-flash", repo.modelRateLimitCalls[0].modelKey) - require.Len(t, upstream.calls, 3, "should have made three retry calls (max attempts)") + require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)") } -// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError -func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) { +// TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess 测试 503 MODEL_CAPACITY_EXHAUSTED 重试成功 +// MODEL_CAPACITY_EXHAUSTED 使用固定 1s 间隔重试,不切换账号 +func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T) { repo := &stubAntigravityAccountRepo{} account := &Account{ ID: 3, @@ -297,7 +305,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi Platform: PlatformAntigravity, } - // 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值 + // 503 + MODEL_CAPACITY_EXHAUSTED + 39s(上游 retryDelay 应被忽略,使用固定 1s) respBody := []byte(`{ "error": { "code": 503, @@ -315,6 +323,14 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi Body: io.NopCloser(bytes.NewReader(respBody)), } + // mock: 第 1 次重试返回 200 成功 + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{ + {StatusCode: http.StatusOK, Header: http.Header{}, Body: io.NopCloser(strings.NewReader(`{"ok":true}`))}, + }, + errors: []error{nil}, + } + params := antigravityRetryLoopParams{ ctx: context.Background(), prefix: "[test]", @@ -323,8 +339,9 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi action: "generateContent", body: []byte(`{"input":"test"}`), accountRepo: repo, + httpUpstream: upstream, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -336,16 +353,67 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi require.NotNil(t, result) require.Equal(t, smartRetryActionBreakWithResp, result.action) - require.Nil(t, result.resp) + require.NotNil(t, result.resp, "should return successful response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) require.Nil(t, result.err) - require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted") - require.Equal(t, account.ID, result.switchError.OriginalAccountID) - require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel) - require.True(t, result.switchError.IsStickySession) + require.Nil(t, result.switchError, "MODEL_CAPACITY_EXHAUSTED should not return switchError") - // 验证模型限流已设置 - require.Len(t, repo.modelRateLimitCalls, 1) - require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey) + // 不应设置模型限流 + require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit") + require.Len(t, upstream.calls, 1, "should have made one retry call before success") +} + +// TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel 测试 MODEL_CAPACITY_EXHAUSTED 上下文取消 +func TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 3, + Name: "acc-3", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + // 立即取消上下文,验证重试循环能正确退出 + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + params := antigravityRetryLoopParams{ + ctx: ctx, + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"}) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.Error(t, result.err, "should return context error") + require.Nil(t, result.switchError, "should not return switchError on context cancel") + require.Empty(t, repo.modelRateLimitCalls, "should not set model rate limit on context cancel") } // TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑 @@ -380,7 +448,7 @@ func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -429,7 +497,7 @@ func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T) accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -480,7 +548,7 @@ func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) { action: "generateContent", body: []byte(`{"input":"test"}`), accountRepo: repo, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -541,7 +609,7 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing httpUpstream: upstream, accountRepo: repo, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, }) @@ -556,19 +624,15 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing require.True(t, switchErr.IsStickySession) } -// TestHandleSmartRetry_NetworkError_ContinuesRetry 测试网络错误时继续重试 -func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) { - // 第一次网络错误,第二次成功 - successResp := &http.Response{ - StatusCode: http.StatusOK, - Header: http.Header{}, - Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), - } +// TestHandleSmartRetry_NetworkError_ExhaustsRetry 测试网络错误时(maxAttempts=1)直接耗尽重试并切换账号 +func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) { + // 唯一一次重试遇到网络错误(nil response) upstream := &mockSmartRetryUpstream{ - responses: []*http.Response{nil, successResp}, // 第一次返回 nil(模拟网络错误) - errors: []error{nil, nil}, // mock 不返回 error,靠 nil response 触发 + responses: []*http.Response{nil}, // 返回 nil(模拟网络错误) + errors: []error{nil}, // mock 不返回 error,靠 nil response 触发 } + repo := &stubAntigravityAccountRepo{} account := &Account{ ID: 8, Name: "acc-8", @@ -600,7 +664,8 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) { action: "generateContent", body: []byte(`{"input":"test"}`), httpUpstream: upstream, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -612,10 +677,15 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) { require.NotNil(t, result) require.Equal(t, smartRetryActionBreakWithResp, result.action) - require.NotNil(t, result.resp, "should return successful response after network error recovery") - require.Equal(t, http.StatusOK, result.resp.StatusCode) - require.Nil(t, result.switchError, "should not return switchError on success") - require.Len(t, upstream.calls, 2, "should have made two retry calls") + require.Nil(t, result.resp, "should not return resp when switchError is set") + require.NotNil(t, result.switchError, "should return switchError after network error exhausted retry") + require.Equal(t, account.ID, result.switchError.OriginalAccountID) + require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel) + require.Len(t, upstream.calls, 1, "should have made one retry call") + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } // TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit 测试无 retryDelay 时使用默认 1 分钟限流 @@ -653,7 +723,7 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) { body: []byte(`{"input":"test"}`), accountRepo: repo, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -674,3 +744,617 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) { require.Len(t, repo.modelRateLimitCalls, 1) require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } + +// --------------------------------------------------------------------------- +// 以下测试覆盖本次改动: +// 1. antigravitySmartRetryMaxAttempts = 1(仅重试 1 次) +// 2. 智能重试失败后清除粘性会话绑定(DeleteSessionAccountID) +// --------------------------------------------------------------------------- + +// TestSmartRetryMaxAttempts_VerifyConstant 验证常量值为 1 +func TestSmartRetryMaxAttempts_VerifyConstant(t *testing.T) { + require.Equal(t, 1, antigravitySmartRetryMaxAttempts, + "antigravitySmartRetryMaxAttempts should be 1 to prevent repeated rate limiting") +} + +// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession +// 核心场景:粘性会话 + 短延迟重试失败 → 必须清除粘性绑定 +func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t *testing.T) { + failRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 10, + Name: "acc-10", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-abc", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + // 验证返回 switchError + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession, "switchError should carry IsStickySession=true") + require.Equal(t, account.ID, result.switchError.OriginalAccountID) + + // 核心断言:DeleteSessionAccountID 被调用,且参数正确 + require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID exactly once") + require.Equal(t, int64(42), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-hash-abc", cache.deleteCalls[0].sessionHash) + + // 验证仅重试 1 次 + require.Len(t, upstream.calls, 1, "should make exactly 1 retry call (maxAttempts=1)") + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) +} + +// TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession +// 非粘性会话 + 短延迟重试失败 → 不应调用 DeleteSessionAccountID(sessionHash 为空) +func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession(t *testing.T) { + failRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 11, + Name: "acc-11", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: false, + groupID: 42, + sessionHash: "", // 非粘性会话,sessionHash 为空 + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.False(t, result.switchError.IsStickySession) + + // 核心断言:sessionHash 为空时不应调用 DeleteSessionAccountID + require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID when sessionHash is empty") +} + +// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic +// 边界:cache 为 nil 时不应 panic +func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic(t *testing.T) { + failRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 12, + Name: "acc-12", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-nil-cache", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + // cache 为 nil,不应 panic + svc := &AntigravityGatewayService{cache: nil} + require.NotPanics(t, func() { + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + }) +} + +// TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession +// 重试成功时不应清除粘性会话(只有失败才清除) +func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession(t *testing.T) { + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 13, + Name: "acc-13", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-success", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp, "should return successful response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Nil(t, result.switchError, "should not return switchError on success") + + // 核心断言:重试成功时不应清除粘性会话 + require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID on successful retry") +} + +// TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry +// 长延迟路径(情况1)在 handleSmartRetry 中不直接调用 DeleteSessionAccountID +// (清除由 handler 层的 shouldClearStickySession 在下次请求时处理) +func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 14, + Name: "acc-14", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 15s >= 7s 阈值 → 走长延迟路径 + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-long-delay", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + + // 长延迟路径不在 handleSmartRetry 中调用 DeleteSessionAccountID + // (由上游 handler 的 shouldClearStickySession 处理) + require.Len(t, cache.deleteCalls, 0, + "long delay path should NOT call DeleteSessionAccountID in handleSmartRetry (handled by handler layer)") +} + +// TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession +// 网络错误耗尽重试 + 粘性会话 → 也应清除粘性绑定 +func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t *testing.T) { + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{nil}, // 网络错误 + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 15, + Name: "acc-15", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 99, + sessionHash: "sticky-net-error", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + + // 核心断言:网络错误耗尽重试后也应清除粘性绑定 + require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID after network error exhausts retry") + require.Equal(t, int64(99), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-net-error", cache.deleteCalls[0].sessionHash) +} + +// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession +// 429 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定 +func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) { + failRespBody := `{ + "error": { + "code": 429, + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 16, + Name: "acc-16", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "code": 429, + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 77, + sessionHash: "sticky-503-short", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + + // 验证粘性绑定被清除 + require.Len(t, cache.deleteCalls, 1) + require.Equal(t, int64(77), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-503-short", cache.deleteCalls[0].sessionHash) + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "gemini-3-pro", repo.modelRateLimitCalls[0].modelKey) +} + +// TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates +// 集成测试:antigravityRetryLoop → handleSmartRetry → switchError 传播 +// 验证 IsStickySession 正确传递到上层,且粘性绑定被清除 +func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates(t *testing.T) { + // 初始 429 响应 + initialRespBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + initialResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(initialRespBody)), + } + + // 智能重试也返回 429 + retryRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + retryResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(retryRespBody)), + } + + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{initialResp, retryResp}, + errors: []error{nil, nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 17, + Name: "acc-17", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + svc := &AntigravityGatewayService{cache: cache} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 55, + sessionHash: "sticky-loop-test", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.Nil(t, result, "should not return result when switchError") + require.NotNil(t, err, "should return error") + + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr, "error should be AntigravityAccountSwitchError") + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, "claude-opus-4-6", switchErr.RateLimitedModel) + require.True(t, switchErr.IsStickySession, "IsStickySession must propagate through retryLoop") + + // 验证粘性绑定被清除 + require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry") + require.Equal(t, int64(55), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash) +} diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index 1eb740f9..068d6a08 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -7,12 +7,14 @@ import ( "log/slog" "strconv" "strings" + "sync" "time" ) const ( antigravityTokenRefreshSkew = 3 * time.Minute antigravityTokenCacheSkew = 5 * time.Minute + antigravityBackfillCooldown = 5 * time.Minute ) // AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) @@ -23,6 +25,7 @@ type AntigravityTokenProvider struct { accountRepo AccountRepository tokenCache AntigravityTokenCache antigravityOAuthService *AntigravityOAuthService + backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time } func NewAntigravityTokenProvider( @@ -93,13 +96,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * if err != nil { return "", err } - newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } - account.Credentials = newCredentials + p.mergeCredentials(account, tokenInfo) if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr) } @@ -113,6 +110,21 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return "", errors.New("access_token not found in credentials") } + // 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现 + // "Invalid project resource name projects/"。 + // 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。 + if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil { + if p.shouldAttemptBackfill(account.ID) { + p.markBackfillAttempted(account.ID) + if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" { + account.Credentials["project_id"] = projectID + if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr) + } + } + } + } + // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) if p.tokenCache != nil { latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) @@ -144,6 +156,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return accessToken, nil } +// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段 +func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) { + newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + account.Credentials = newCredentials +} + +// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试) +func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool { + if v, ok := p.backfillCooldown.Load(accountID); ok { + if lastAttempt, ok := v.(time.Time); ok { + return time.Since(lastAttempt) > antigravityBackfillCooldown + } + } + return true +} + +func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) { + p.backfillCooldown.Store(accountID, time.Now()) +} + func AntigravityTokenCacheKey(account *Account) string { projectID := strings.TrimSpace(account.GetCredential("project_id")) if projectID != "" { diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index d66059dd..07523597 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -1,6 +1,10 @@ package service -import "time" +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" +) // API Key status constants const ( @@ -19,10 +23,14 @@ type APIKey struct { Status string IPWhitelist []string IPBlacklist []string - CreatedAt time.Time - UpdatedAt time.Time - User *User - Group *Group + // 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。 + CompiledIPWhitelist *ip.CompiledIPRules `json:"-"` + CompiledIPBlacklist *ip.CompiledIPRules `json:"-"` + LastUsedAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time + User *User + Group *Group // Quota fields Quota float64 // Quota limit in USD (0 = unlimited) diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index d15b5817..4240be23 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -44,6 +44,10 @@ type APIKeyAuthGroupSnapshot struct { ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index f5bba7d0..30eb8d74 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -6,8 +6,7 @@ import ( "encoding/hex" "errors" "fmt" - "math/rand" - "sync" + "math/rand/v2" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -23,12 +22,6 @@ type apiKeyAuthCacheConfig struct { singleflight bool } -var ( - jitterRandMu sync.Mutex - // 认证缓存抖动使用独立随机源,避免全局 Seed - jitterRand = rand.New(rand.NewSource(time.Now().UnixNano())) -) - func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig { if cfg == nil { return apiKeyAuthCacheConfig{} @@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool { return c.negativeTTL > 0 } +// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。 +// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。 func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { if ttl <= 0 { return ttl @@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { percent = 100 } delta := float64(percent) / 100 - jitterRandMu.Lock() - randVal := jitterRand.Float64() - jitterRandMu.Unlock() + randVal := rand.Float64() factor := 1 - delta + randVal*(2*delta) if factor <= 0 { return ttl @@ -238,6 +231,10 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice4K: apiKey.Group.ImagePrice4K, + SoraImagePrice360: apiKey.Group.SoraImagePrice360, + SoraImagePrice540: apiKey.Group.SoraImagePrice540, + SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest, @@ -288,6 +285,10 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice4K: snapshot.Group.ImagePrice4K, + SoraImagePrice360: snapshot.Group.SoraImagePrice360, + SoraImagePrice540: snapshot.Group.SoraImagePrice540, + SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest, @@ -297,5 +298,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho SupportedModelScopes: snapshot.Group.SupportedModelScopes, } } + s.compileAPIKeyIPRules(apiKey) return apiKey } diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index cb1dd60a..0d073077 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -5,6 +5,8 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "strconv" + "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -32,6 +34,9 @@ var ( const ( apiKeyMaxErrorsPerHour = 20 + apiKeyLastUsedMinTouch = 30 * time.Second + // DB 写失败后的短退避,避免请求路径持续同步重试造成写风暴与高延迟。 + apiKeyLastUsedFailBackoff = 5 * time.Second ) type APIKeyRepository interface { @@ -58,6 +63,7 @@ type APIKeyRepository interface { // Quota methods IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) + UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error } // APIKeyCache defines cache operations for API key service @@ -125,6 +131,8 @@ type APIKeyService struct { authCacheL1 *ristretto.Cache authCfg apiKeyAuthCacheConfig authGroup singleflight.Group + lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time) + lastUsedTouchSF singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -150,6 +158,14 @@ func NewAPIKeyService( return svc } +func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) { + if apiKey == nil { + return + } + apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist) + apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist) +} + // GenerateKey 生成随机API Key func (s *APIKeyService) GenerateKey() (string, error) { // 生成32字节随机数据 @@ -324,6 +340,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK } s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -355,6 +372,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -367,6 +385,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } @@ -383,6 +402,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } else { @@ -394,6 +414,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } @@ -403,6 +424,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro return nil, fmt.Errorf("get api key: %w", err) } apiKey.Key = key + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -502,6 +524,7 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req } s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -527,6 +550,7 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro if err := s.apiKeyRepo.Delete(ctx, id); err != nil { return fmt.Errorf("delete api key: %w", err) } + s.lastUsedTouchL1.Delete(id) return nil } @@ -558,6 +582,38 @@ func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, * return apiKey, user, nil } +// TouchLastUsed 通过防抖更新 api_keys.last_used_at,减少高频写放大。 +// 该操作为尽力而为,不应阻塞主请求链路。 +func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error { + if keyID <= 0 { + return nil + } + + now := time.Now() + if v, ok := s.lastUsedTouchL1.Load(keyID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) { + return nil + } + } + + _, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) { + latest := time.Now() + if v, ok := s.lastUsedTouchL1.Load(keyID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) { + return nil, nil + } + } + + if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil { + s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff)) + return nil, fmt.Errorf("touch api key last used: %w", err) + } + s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch)) + return nil, nil + }) + return err +} + // IncrementUsage 增加API Key使用次数(可选:用于统计) func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error { // 使用Redis计数器 diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 14ecbf39..2357813b 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -103,6 +103,10 @@ func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount panic("unexpected IncrementQuotaUsed call") } +func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + panic("unexpected UpdateLastUsed call") +} + type authCacheStub struct { getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) setAuthKeys []string diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index d4d12144..79757808 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -24,10 +24,13 @@ import ( // - deleteErr: 模拟 Delete 返回的错误 // - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证 type apiKeyRepoStub struct { - apiKey *APIKey // GetKeyAndOwnerID 的返回值 - getByIDErr error // GetKeyAndOwnerID 的错误返回值 - deleteErr error // Delete 的错误返回值 - deletedIDs []int64 // 记录已删除的 API Key ID 列表 + apiKey *APIKey // GetKeyAndOwnerID 的返回值 + getByIDErr error // GetKeyAndOwnerID 的错误返回值 + deleteErr error // Delete 的错误返回值 + deletedIDs []int64 // 记录已删除的 API Key ID 列表 + updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error + touchedIDs []int64 + touchedUsedAts []time.Time } // 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题 @@ -122,6 +125,15 @@ func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amoun panic("unexpected IncrementQuotaUsed call") } +func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + s.touchedIDs = append(s.touchedIDs, id) + s.touchedUsedAts = append(s.touchedUsedAts, usedAt) + if s.updateLastUsed != nil { + return s.updateLastUsed(ctx, id, usedAt) + } + return nil +} + // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。 // @@ -214,12 +226,15 @@ func TestApiKeyService_Delete_Success(t *testing.T) { } cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} + svc.lastUsedTouchL1.Store(int64(42), time.Now()) err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7 require.NoError(t, err) require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除 require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除 require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys) + _, exists := svc.lastUsedTouchL1.Load(int64(42)) + require.False(t, exists, "delete should clear touch debounce cache") } // TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。 diff --git a/backend/internal/service/api_key_service_touch_last_used_test.go b/backend/internal/service/api_key_service_touch_last_used_test.go new file mode 100644 index 00000000..b49bf9ce --- /dev/null +++ b/backend/internal/service/api_key_service_touch_last_used_test.go @@ -0,0 +1,160 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestAPIKeyService_TouchLastUsed_InvalidKeyID(t *testing.T) { + repo := &apiKeyRepoStub{ + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + return errors.New("should not be called") + }, + } + svc := &APIKeyService{apiKeyRepo: repo} + + require.NoError(t, svc.TouchLastUsed(context.Background(), 0)) + require.NoError(t, svc.TouchLastUsed(context.Background(), -1)) + require.Empty(t, repo.touchedIDs) +} + +func TestAPIKeyService_TouchLastUsed_FirstTouchSucceeds(t *testing.T) { + repo := &apiKeyRepoStub{} + svc := &APIKeyService{apiKeyRepo: repo} + + err := svc.TouchLastUsed(context.Background(), 123) + require.NoError(t, err) + require.Equal(t, []int64{123}, repo.touchedIDs) + require.Len(t, repo.touchedUsedAts, 1) + require.False(t, repo.touchedUsedAts[0].IsZero()) + + cached, ok := svc.lastUsedTouchL1.Load(int64(123)) + require.True(t, ok, "successful touch should update debounce cache") + _, isTime := cached.(time.Time) + require.True(t, isTime) +} + +func TestAPIKeyService_TouchLastUsed_DebouncedWithinWindow(t *testing.T) { + repo := &apiKeyRepoStub{} + svc := &APIKeyService{apiKeyRepo: repo} + + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + + require.Equal(t, []int64{123}, repo.touchedIDs, "second touch within debounce window should not hit repository") +} + +func TestAPIKeyService_TouchLastUsed_ExpiredDebounceTouchesAgain(t *testing.T) { + repo := &apiKeyRepoStub{} + svc := &APIKeyService{apiKeyRepo: repo} + + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + + // 强制将 debounce 时间回拨到窗口之外,触发第二次写库。 + svc.lastUsedTouchL1.Store(int64(123), time.Now().Add(-apiKeyLastUsedMinTouch-time.Second)) + + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + require.Len(t, repo.touchedIDs, 2) + require.Equal(t, int64(123), repo.touchedIDs[0]) + require.Equal(t, int64(123), repo.touchedIDs[1]) +} + +func TestAPIKeyService_TouchLastUsed_RepoError(t *testing.T) { + repo := &apiKeyRepoStub{ + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + return errors.New("db write failed") + }, + } + svc := &APIKeyService{apiKeyRepo: repo} + + err := svc.TouchLastUsed(context.Background(), 123) + require.Error(t, err) + require.ErrorContains(t, err, "touch api key last used") + require.Equal(t, []int64{123}, repo.touchedIDs) + + cached, ok := svc.lastUsedTouchL1.Load(int64(123)) + require.True(t, ok, "failed touch should still update retry debounce cache") + _, isTime := cached.(time.Time) + require.True(t, isTime) +} + +func TestAPIKeyService_TouchLastUsed_RepoErrorDebounced(t *testing.T) { + repo := &apiKeyRepoStub{ + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + return errors.New("db write failed") + }, + } + svc := &APIKeyService{apiKeyRepo: repo} + + firstErr := svc.TouchLastUsed(context.Background(), 456) + require.Error(t, firstErr) + require.ErrorContains(t, firstErr, "touch api key last used") + + secondErr := svc.TouchLastUsed(context.Background(), 456) + require.NoError(t, secondErr, "failed touch should be debounced and skip immediate retry") + require.Equal(t, []int64{456}, repo.touchedIDs, "debounced retry should not hit repository again") +} + +type touchSingleflightRepo struct { + *apiKeyRepoStub + mu sync.Mutex + calls int + blockCh chan struct{} +} + +func (r *touchSingleflightRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + r.mu.Lock() + r.calls++ + r.mu.Unlock() + <-r.blockCh + return nil +} + +func TestAPIKeyService_TouchLastUsed_ConcurrentFirstTouchDeduplicated(t *testing.T) { + repo := &touchSingleflightRepo{ + apiKeyRepoStub: &apiKeyRepoStub{}, + blockCh: make(chan struct{}), + } + svc := &APIKeyService{apiKeyRepo: repo} + + const workers = 20 + startCh := make(chan struct{}) + errCh := make(chan error, workers) + var wg sync.WaitGroup + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-startCh + errCh <- svc.TouchLastUsed(context.Background(), 321) + }() + } + + close(startCh) + + require.Eventually(t, func() bool { + repo.mu.Lock() + defer repo.mu.Unlock() + return repo.calls >= 1 + }, time.Second, 10*time.Millisecond) + + close(repo.blockCh) + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Equal(t, 1, repo.calls, "并发首次 touch 只应写库一次") +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index fb8aaf9c..fe3a0f25 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -7,13 +7,13 @@ import ( "encoding/hex" "errors" "fmt" - "log" "net/mail" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" @@ -56,15 +56,20 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { - userRepo UserRepository - redeemRepo RedeemCodeRepository - refreshTokenCache RefreshTokenCache - cfg *config.Config - settingService *SettingService - emailService *EmailService - turnstileService *TurnstileService - emailQueueService *EmailQueueService - promoService *PromoService + userRepo UserRepository + redeemRepo RedeemCodeRepository + refreshTokenCache RefreshTokenCache + cfg *config.Config + settingService *SettingService + emailService *EmailService + turnstileService *TurnstileService + emailQueueService *EmailQueueService + promoService *PromoService + defaultSubAssigner DefaultSubscriptionAssigner +} + +type DefaultSubscriptionAssigner interface { + AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) } // NewAuthService 创建认证服务实例 @@ -78,17 +83,19 @@ func NewAuthService( turnstileService *TurnstileService, emailQueueService *EmailQueueService, promoService *PromoService, + defaultSubAssigner DefaultSubscriptionAssigner, ) *AuthService { return &AuthService{ - userRepo: userRepo, - redeemRepo: redeemRepo, - refreshTokenCache: refreshTokenCache, - cfg: cfg, - settingService: settingService, - emailService: emailService, - turnstileService: turnstileService, - emailQueueService: emailQueueService, - promoService: promoService, + userRepo: userRepo, + redeemRepo: redeemRepo, + refreshTokenCache: refreshTokenCache, + cfg: cfg, + settingService: settingService, + emailService: emailService, + turnstileService: turnstileService, + emailQueueService: emailQueueService, + promoService: promoService, + defaultSubAssigner: defaultSubAssigner, } } @@ -118,12 +125,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw // 验证邀请码 redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) if err != nil { - log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err) + logger.LegacyPrintf("service.auth", "[Auth] Invalid invitation code: %s, error: %v", invitationCode, err) return "", nil, ErrInvitationCodeInvalid } // 检查类型和状态 if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { - log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status) + logger.LegacyPrintf("service.auth", "[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status) return "", nil, ErrInvitationCodeInvalid } invitationRedeemCode = redeemCode @@ -134,7 +141,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw // 如果邮件验证已开启但邮件服务未配置,拒绝注册 // 这是一个配置错误,不应该允许绕过验证 if s.emailService == nil { - log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verification enabled but email service not configured, rejecting registration") return "", nil, ErrServiceUnavailable } if verifyCode == "" { @@ -149,7 +156,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error checking email exists: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err) return "", nil, ErrServiceUnavailable } if existsEmail { @@ -185,22 +192,23 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw if errors.Is(err, ErrEmailExists) { return "", nil, ErrEmailExists } - log.Printf("[Auth] Database error creating user: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } + s.assignDefaultSubscriptions(ctx, user.ID) // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { // 邀请码标记失败不影响注册,只记录日志 - log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) } } // 应用优惠码(如果提供且功能已启用) if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { // 优惠码应用失败不影响注册,只记录日志 - log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to apply promo code for user %d: %v", user.ID, err) } else { // 重新获取用户信息以获取更新后的余额 if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil { @@ -237,7 +245,7 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error checking email exists: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err) return ErrServiceUnavailable } if existsEmail { @@ -260,11 +268,11 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { // SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时 func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { - log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] SendVerifyCodeAsync called for email: %s", email) // 检查是否开放注册(默认关闭) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { - log.Println("[Auth] Registration is disabled") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Registration is disabled") return nil, ErrRegDisabled } @@ -275,17 +283,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error checking email exists: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err) return nil, ErrServiceUnavailable } if existsEmail { - log.Printf("[Auth] Email already exists: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Email already exists: %s", email) return nil, ErrEmailExists } // 检查邮件队列服务是否配置 if s.emailQueueService == nil { - log.Println("[Auth] Email queue service not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Email queue service not configured") return nil, errors.New("email queue service not configured") } @@ -296,45 +304,56 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S } // 异步发送 - log.Printf("[Auth] Enqueueing verify code for: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Enqueueing verify code for: %s", email) if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil { - log.Printf("[Auth] Failed to enqueue: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue: %v", err) return nil, fmt.Errorf("enqueue verify code: %w", err) } - log.Printf("[Auth] Verify code enqueued successfully for: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Verify code enqueued successfully for: %s", email) return &SendVerifyCodeResult{ Countdown: 60, // 60秒倒计时 }, nil } +// VerifyTurnstileForRegister 在注册场景下验证 Turnstile。 +// 当邮箱验证开启且已提交验证码时,说明验证码发送阶段已完成 Turnstile 校验, +// 此处跳过二次校验,避免一次性 token 在注册提交时重复使用导致误报失败。 +func (s *AuthService) VerifyTurnstileForRegister(ctx context.Context, token, remoteIP, verifyCode string) error { + if s.IsEmailVerifyEnabled(ctx) && strings.TrimSpace(verifyCode) != "" { + logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verify flow detected, skip duplicate Turnstile check on register") + return nil + } + return s.VerifyTurnstile(ctx, token, remoteIP) +} + // VerifyTurnstile 验证Turnstile token func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required if required { if s.settingService == nil { - log.Println("[Auth] Turnstile required but settings service is not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile required but settings service is not configured") return ErrTurnstileNotConfigured } enabled := s.settingService.IsTurnstileEnabled(ctx) secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != "" if !enabled || !secretConfigured { - log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured) + logger.LegacyPrintf("service.auth", "[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured) return ErrTurnstileNotConfigured } } if s.turnstileService == nil { if required { - log.Println("[Auth] Turnstile required but service not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile required but service not configured") return ErrTurnstileNotConfigured } return nil // 服务未配置则跳过验证 } if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" { - log.Println("[Auth] Turnstile enabled but secret key not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile enabled but secret key not configured") } return s.turnstileService.VerifyToken(ctx, token, remoteIP) @@ -373,7 +392,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string return "", nil, ErrInvalidCredentials } // 记录数据库错误但不暴露给用户 - log.Printf("[Auth] Database error during login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error during login: %v", err) return "", nil, ErrServiceUnavailable } @@ -426,7 +445,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username randomPassword, err := randomHexString(32) if err != nil { - log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err) return "", nil, ErrServiceUnavailable } hashedPassword, err := s.HashPassword(randomPassword) @@ -457,18 +476,19 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username // 并发场景:GetByEmail 与 Create 之间用户被创建。 user, err = s.userRepo.GetByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error getting user after conflict: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) return "", nil, ErrServiceUnavailable } } else { - log.Printf("[Auth] Database error creating oauth user: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) return "", nil, ErrServiceUnavailable } } else { user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) } } else { - log.Printf("[Auth] Database error during oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) return "", nil, ErrServiceUnavailable } } @@ -481,7 +501,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username if user.Username == "" && username != "" { user.Username = username if err := s.userRepo.Update(ctx, user); err != nil { - log.Printf("[Auth] Failed to update username after oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } @@ -523,7 +543,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema randomPassword, err := randomHexString(32) if err != nil { - log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err) return nil, nil, ErrServiceUnavailable } hashedPassword, err := s.HashPassword(randomPassword) @@ -552,18 +572,19 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema if errors.Is(err, ErrEmailExists) { user, err = s.userRepo.GetByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error getting user after conflict: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) return nil, nil, ErrServiceUnavailable } } else { - log.Printf("[Auth] Database error creating oauth user: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) return nil, nil, ErrServiceUnavailable } } else { user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) } } else { - log.Printf("[Auth] Database error during oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) return nil, nil, ErrServiceUnavailable } } @@ -575,7 +596,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema if user.Username == "" && username != "" { user.Username = username if err := s.userRepo.Update(ctx, user); err != nil { - log.Printf("[Auth] Failed to update username after oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } @@ -586,6 +607,23 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } +func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { + if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { + return + } + items := s.settingService.GetDefaultSubscriptions(ctx) + for _, item := range items { + if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: "auto assigned by default user subscriptions setting", + }); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) + } + } +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 @@ -715,7 +753,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) ( if errors.Is(err, ErrUserNotFound) { return "", ErrInvalidToken } - log.Printf("[Auth] Database error refreshing token: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error refreshing token: %v", err) return "", ErrServiceUnavailable } @@ -756,16 +794,16 @@ func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendB if err != nil { if errors.Is(err, ErrUserNotFound) { // Security: Log but don't reveal that user doesn't exist - log.Printf("[Auth] Password reset requested for non-existent email: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset requested for non-existent email: %s", email) return "", "", false } - log.Printf("[Auth] Database error checking email for password reset: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email for password reset: %v", err) return "", "", false } // Check if user is active if !user.IsActive() { - log.Printf("[Auth] Password reset requested for inactive user: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset requested for inactive user: %s", email) return "", "", false } @@ -797,11 +835,11 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB } if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil { - log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to send password reset email to %s: %v", email, err) return nil // Silent success to prevent enumeration } - log.Printf("[Auth] Password reset email sent to: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset email sent to: %s", email) return nil } @@ -821,11 +859,11 @@ func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, fron } if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil { - log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue password reset email for %s: %v", email, err) return nil // Silent success to prevent enumeration } - log.Printf("[Auth] Password reset email enqueued for: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset email enqueued for: %s", email) return nil } @@ -852,7 +890,7 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo if errors.Is(err, ErrUserNotFound) { return ErrInvalidResetToken // Token was valid but user was deleted } - log.Printf("[Auth] Database error getting user for password reset: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user for password reset: %v", err) return ErrServiceUnavailable } @@ -872,17 +910,17 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo user.TokenVersion++ // Invalidate all existing tokens if err := s.userRepo.Update(ctx, user); err != nil { - log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Database error updating password for user %d: %v", user.ID, err) return ErrServiceUnavailable } // Also revoke all refresh tokens for this user if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil { - log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err) // Don't return error - password was already changed successfully } - log.Printf("[Auth] Password reset successful for user: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset successful for user: %s", email) return nil } @@ -961,13 +999,13 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami // 添加到用户Token集合 if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil { - log.Printf("[Auth] Failed to add token to user set: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to add token to user set: %v", err) // 不影响主流程 } // 添加到家族Token集合 if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil { - log.Printf("[Auth] Failed to add token to family set: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to add token to family set: %v", err) // 不影响主流程 } @@ -994,10 +1032,10 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) if err != nil { if errors.Is(err, ErrRefreshTokenNotFound) { // Token不存在,可能是已被使用(Token轮转)或已过期 - log.Printf("[Auth] Refresh token not found, possible reuse attack") + logger.LegacyPrintf("service.auth", "[Auth] Refresh token not found, possible reuse attack") return nil, ErrRefreshTokenInvalid } - log.Printf("[Auth] Error getting refresh token: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Error getting refresh token: %v", err) return nil, ErrServiceUnavailable } @@ -1016,7 +1054,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) return nil, ErrRefreshTokenInvalid } - log.Printf("[Auth] Database error getting user for token refresh: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user for token refresh: %v", err) return nil, ErrServiceUnavailable } @@ -1036,7 +1074,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) // Token轮转:立即使旧Token失效 if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil { - log.Printf("[Auth] Failed to delete old refresh token: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to delete old refresh token: %v", err) // 继续处理,不影响主流程 } diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index f1685be5..1999e759 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -56,6 +56,21 @@ type emailCacheStub struct { err error } +type defaultSubscriptionAssignerStub struct { + calls []AssignSubscriptionInput + err error +} + +func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { + if input != nil { + s.calls = append(s.calls, *input) + } + if s.err != nil { + return nil, false, s.err + } + return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil +} + func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) { if s.err != nil { return nil, s.err @@ -123,6 +138,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E nil, nil, nil, // promoService + nil, // defaultSubAssigner ) } @@ -315,3 +331,89 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) { require.NotEmpty(t, newToken) }) } + +func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 0 + + require.Equal(t, 24*3600, service.GetAccessTokenExpiresIn()) +} + +func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 90 + + require.Equal(t, 90*60, service.GetAccessTokenExpiresIn()) +} + +func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 0 + + user := &User{ + ID: 1, + Email: "test@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + + token, err := service.GenerateToken(user) + require.NoError(t, err) + + claims, err := service.ValidateToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.NotNil(t, claims.IssuedAt) + require.NotNil(t, claims.ExpiresAt) + + require.WithinDuration(t, claims.IssuedAt.Time.Add(24*time.Hour), claims.ExpiresAt.Time, 2*time.Second) +} + +func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 90 + + user := &User{ + ID: 2, + Email: "test2@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + + token, err := service.GenerateToken(user) + require.NoError(t, err) + + claims, err := service.ValidateToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.NotNil(t, claims.IssuedAt) + require.NotNil(t, claims.ExpiresAt) + + require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second) +} + +func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { + repo := &userRepoStub{nextID: 42} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "default-sub@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Len(t, assigner.calls, 2) + require.Equal(t, int64(42), assigner.calls[0].UserID) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) + require.Equal(t, int64(12), assigner.calls[1].GroupID) + require.Equal(t, 7, assigner.calls[1].ValidityDays) +} diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go new file mode 100644 index 00000000..36cb1e06 --- /dev/null +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type turnstileVerifierSpy struct { + called int + lastToken string + result *TurnstileVerifyResponse + err error +} + +func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) { + s.called++ + s.lastToken = token + if s.err != nil { + return nil, s.err + } + if s.result != nil { + return s.result, nil + } + return &TurnstileVerifyResponse{Success: true}, nil +} + +func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService { + cfg := &config.Config{ + Server: config.ServerConfig{ + Mode: "release", + }, + Turnstile: config.TurnstileConfig{ + Required: true, + }, + } + + settingService := NewSettingService(&settingRepoStub{values: settings}, cfg) + turnstileService := NewTurnstileService(settingService, verifier) + + return NewAuthService( + &userRepoStub{}, + nil, // redeemRepo + nil, // refreshTokenCache + cfg, + settingService, + nil, // emailService + turnstileService, + nil, // emailQueueService + nil, // promoService + nil, // defaultSubAssigner + ) +} + +func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "true", + SettingKeyTurnstileEnabled: "true", + SettingKeyTurnstileSecretKey: "secret", + SettingKeyRegistrationEnabled: "true", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456") + require.NoError(t, err) + require.Equal(t, 0, verifier.called) +} + +func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "true", + SettingKeyTurnstileEnabled: "true", + SettingKeyTurnstileSecretKey: "secret", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "") + require.ErrorIs(t, err, ErrTurnstileVerificationFailed) +} + +func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "false", + SettingKeyTurnstileEnabled: "true", + SettingKeyTurnstileSecretKey: "secret", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456") + require.NoError(t, err) + require.Equal(t, 1, verifier.called) + require.Equal(t, "turnstile-token", verifier.lastToken) +} diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index c09cafb9..1a76f5f6 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -3,13 +3,15 @@ package service import ( "context" "fmt" - "log" + "strconv" "sync" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "golang.org/x/sync/singleflight" ) // 错误定义 @@ -58,6 +60,7 @@ const ( cacheWriteBufferSize = 1000 // 任务队列缓冲大小 cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔 + balanceLoadTimeout = 3 * time.Second ) // cacheWriteTask 缓存写入任务 @@ -82,6 +85,9 @@ type BillingCacheService struct { cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup cacheWriteStopOnce sync.Once + cacheWriteMu sync.RWMutex + stopped atomic.Bool + balanceLoadSF singleflight.Group // 丢弃日志节流计数器(减少高负载下日志噪音) cacheWriteDropFullCount uint64 cacheWriteDropFullLastLog int64 @@ -105,35 +111,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo // Stop 关闭缓存写入工作池 func (s *BillingCacheService) Stop() { s.cacheWriteStopOnce.Do(func() { - if s.cacheWriteChan == nil { + s.stopped.Store(true) + + s.cacheWriteMu.Lock() + ch := s.cacheWriteChan + if ch != nil { + close(ch) + } + s.cacheWriteMu.Unlock() + + if ch == nil { return } - close(s.cacheWriteChan) s.cacheWriteWg.Wait() - s.cacheWriteChan = nil + + s.cacheWriteMu.Lock() + if s.cacheWriteChan == ch { + s.cacheWriteChan = nil + } + s.cacheWriteMu.Unlock() }) } func (s *BillingCacheService) startCacheWriteWorkers() { - s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize) + ch := make(chan cacheWriteTask, cacheWriteBufferSize) + s.cacheWriteChan = ch for i := 0; i < cacheWriteWorkerCount; i++ { s.cacheWriteWg.Add(1) - go s.cacheWriteWorker() + go s.cacheWriteWorker(ch) } } // enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。 func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) { - if s.cacheWriteChan == nil { + if s.stopped.Load() { + s.logCacheWriteDrop(task, "closed") return false } - defer func() { - if recovered := recover(); recovered != nil { - // 队列已关闭时可能触发 panic,记录后静默失败。 - s.logCacheWriteDrop(task, "closed") - enqueued = false - } - }() + + s.cacheWriteMu.RLock() + defer s.cacheWriteMu.RUnlock() + + if s.cacheWriteChan == nil { + s.logCacheWriteDrop(task, "closed") + return false + } + select { case s.cacheWriteChan <- task: return true @@ -144,9 +167,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b } } -func (s *BillingCacheService) cacheWriteWorker() { +func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) { defer s.cacheWriteWg.Done() - for task := range s.cacheWriteChan { + for task := range ch { ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) switch task.kind { case cacheWriteSetBalance: @@ -156,13 +179,13 @@ func (s *BillingCacheService) cacheWriteWorker() { case cacheWriteUpdateSubscriptionUsage: if s.cache != nil { if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil { - log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err) } } case cacheWriteDeductBalance: if s.cache != nil { if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil { - log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err) } } } @@ -216,7 +239,7 @@ func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason stri if dropped == 0 { return } - log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)", + logger.LegacyPrintf("service.billing_cache", "Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)", reason, dropped, cacheWriteDropLogInterval, @@ -243,19 +266,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) return balance, nil } - // 缓存未命中,从数据库读取 - balance, err = s.getUserBalanceFromDB(ctx, userID) + // 缓存未命中:singleflight 合并同一 userID 的并发回源请求。 + value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) { + loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout) + defer cancel() + + balance, err := s.getUserBalanceFromDB(loadCtx, userID) + if err != nil { + return nil, err + } + + // 异步建立缓存 + _ = s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteSetBalance, + userID: userID, + balance: balance, + }) + return balance, nil + }) if err != nil { return 0, err } - - // 异步建立缓存 - _ = s.enqueueCacheWrite(cacheWriteTask{ - kind: cacheWriteSetBalance, - userID: userID, - balance: balance, - }) - + balance, ok := value.(float64) + if !ok { + return 0, fmt.Errorf("unexpected balance type: %T", value) + } return balance, nil } @@ -274,7 +309,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, return } if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil { - log.Printf("Warning: set balance cache failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: set balance cache failed for user %d: %v", userID, err) } } @@ -302,7 +337,7 @@ func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) { ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) defer cancel() if err := s.DeductBalanceCache(ctx, userID, amount); err != nil { - log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache fallback failed for user %d: %v", userID, err) } } @@ -312,7 +347,7 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID return nil } if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil { - log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: invalidate balance cache failed for user %d: %v", userID, err) return err } return nil @@ -396,7 +431,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, return } if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil { - log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err) } } @@ -425,7 +460,7 @@ func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64 ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) defer cancel() if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil { - log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err) } } @@ -435,7 +470,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID return nil } if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil { - log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err) return err } return nil @@ -474,7 +509,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI if s.circuitBreaker != nil { s.circuitBreaker.OnFailure(err) } - log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing balance check failed for user %d: %v", userID, err) return ErrBillingServiceUnavailable.WithCause(err) } if s.circuitBreaker != nil { @@ -496,7 +531,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, if s.circuitBreaker != nil { s.circuitBreaker.OnFailure(err) } - log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err) return ErrBillingServiceUnavailable.WithCause(err) } if s.circuitBreaker != nil { @@ -585,7 +620,7 @@ func (b *billingCircuitBreaker) Allow() bool { } b.state = billingCircuitHalfOpen b.halfOpenRemaining = b.halfOpenRequests - log.Printf("ALERT: billing circuit breaker entering half-open state") + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker entering half-open state") fallthrough case billingCircuitHalfOpen: if b.halfOpenRemaining <= 0 { @@ -612,7 +647,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) { b.state = billingCircuitOpen b.openedAt = time.Now() b.halfOpenRemaining = 0 - log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after half-open failure: %v", err) return default: b.failures++ @@ -620,7 +655,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) { b.state = billingCircuitOpen b.openedAt = time.Now() b.halfOpenRemaining = 0 - log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err) } } } @@ -641,9 +676,9 @@ func (b *billingCircuitBreaker) OnSuccess() { // 只有状态真正发生变化时才记录日志 if previousState != billingCircuitClosed { - log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState)) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState)) } else if previousFailures > 0 { - log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures) + logger.LegacyPrintf("service.billing_cache", "INFO: billing circuit breaker failures reset from %d", previousFailures) } } diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go new file mode 100644 index 00000000..1b12c402 --- /dev/null +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -0,0 +1,115 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type billingCacheMissStub struct { + setBalanceCalls atomic.Int64 +} + +func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) { + return 0, errors.New("cache miss") +} + +func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error { + s.setBalanceCalls.Add(1) + return nil +} + +func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { + return nil +} + +func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error { + return nil +} + +func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) { + return nil, errors.New("cache miss") +} + +func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error { + return nil +} + +func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { + return nil +} + +func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { + return nil +} + +type balanceLoadUserRepoStub struct { + mockUserRepo + calls atomic.Int64 + delay time.Duration + balance float64 +} + +func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) { + s.calls.Add(1) + if s.delay > 0 { + select { + case <-time.After(s.delay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + return &User{ID: id, Balance: s.balance}, nil +} + +func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { + cache := &billingCacheMissStub{} + userRepo := &balanceLoadUserRepoStub{ + delay: 80 * time.Millisecond, + balance: 12.34, + } + svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{}) + t.Cleanup(svc.Stop) + + const goroutines = 16 + start := make(chan struct{}) + var wg sync.WaitGroup + errCh := make(chan error, goroutines) + balCh := make(chan float64, goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + bal, err := svc.GetUserBalance(context.Background(), 99) + errCh <- err + balCh <- bal + }() + } + + close(start) + wg.Wait() + close(errCh) + close(balCh) + + for err := range errCh { + require.NoError(t, err) + } + for bal := range balCh { + require.Equal(t, 12.34, bal) + } + + require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并") + require.Eventually(t, func() bool { + return cache.setBalanceCalls.Load() >= 1 + }, time.Second, 10*time.Millisecond) +} diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go index 445d5319..4e5f50e2 100644 --- a/backend/internal/service/billing_cache_service_test.go +++ b/backend/internal/service/billing_cache_service_test.go @@ -73,3 +73,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { return atomic.LoadInt64(&cache.subscriptionUpdates) > 0 }, 2*time.Second, 10*time.Millisecond) } + +func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { + cache := &billingCacheWorkerStub{} + svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) + svc.Stop() + + enqueued := svc.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteDeductBalance, + userID: 1, + amount: 1, + }) + require.False(t, enqueued) +} diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index db5a9708..6abd1e53 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -31,8 +31,8 @@ type ModelPricing struct { OutputPricePerToken float64 // 每token输出价格 (USD) CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) - CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退 - CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退 + CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD) + CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) SupportsCacheBreakdown bool // 是否支持详细的缓存分类 } @@ -133,6 +133,18 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok SupportsCacheBreakdown: false, } + + // Claude 4.6 Opus (与4.5同价) + s.fallbackPrices["claude-opus-4.6"] = s.fallbackPrices["claude-opus-4.5"] + + // Gemini 3.1 Pro + s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{ + InputPricePerToken: 2e-6, // $2 per MTok + OutputPricePerToken: 12e-6, // $12 per MTok + CacheCreationPricePerToken: 2e-6, // $2 per MTok + CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok + SupportsCacheBreakdown: false, + } } // getFallbackPricing 根据模型系列获取回退价格 @@ -141,6 +153,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { // 按模型系列匹配 if strings.Contains(modelLower, "opus") { + if strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6") { + return s.fallbackPrices["claude-opus-4.6"] + } if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") { return s.fallbackPrices["claude-opus-4.5"] } @@ -158,6 +173,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { } return s.fallbackPrices["claude-3-haiku"] } + if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") { + return s.fallbackPrices["gemini-3.1-pro"] + } // 默认使用Sonnet价格 return s.fallbackPrices["claude-sonnet-4"] @@ -172,12 +190,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { if s.pricingService != nil { litellmPricing := s.pricingService.GetModelPricing(model) if litellmPricing != nil { + // 启用 5m/1h 分类计费的条件: + // 1. 存在 1h 价格 + // 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费) + price5m := litellmPricing.CacheCreationInputTokenCost + price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr + enableBreakdown := price1h > 0 && price1h > price5m return &ModelPricing{ InputPricePerToken: litellmPricing.InputCostPerToken, OutputPricePerToken: litellmPricing.OutputCostPerToken, CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, - SupportsCacheBreakdown: false, + CacheCreation5mPrice: price5m, + CacheCreation1hPrice: price1h, + SupportsCacheBreakdown: enableBreakdown, }, nil } } @@ -209,9 +235,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul // 计算缓存费用 if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { - // 支持详细缓存分类的模型(5分钟/1小时缓存) - breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice + - float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice + // 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token) + if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 { + // API 未返回 ephemeral 明细,回退到全部按 5m 单价计费 + breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice + } else { + breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice + + float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice + } } else { // 标准缓存创建价格(per-token) breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken @@ -280,10 +311,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage // 范围内部分:正常计费 inRangeTokens := UsageTokens{ - InputTokens: inRangeInputTokens, - OutputTokens: tokens.OutputTokens, // 输出只算一次 - CacheCreationTokens: tokens.CacheCreationTokens, - CacheReadTokens: inRangeCacheTokens, + InputTokens: inRangeInputTokens, + OutputTokens: tokens.OutputTokens, // 输出只算一次 + CacheCreationTokens: tokens.CacheCreationTokens, + CacheReadTokens: inRangeCacheTokens, + CacheCreation5mTokens: tokens.CacheCreation5mTokens, + CacheCreation1hTokens: tokens.CacheCreation1hTokens, } inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) if err != nil { @@ -297,7 +330,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage } outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier) if err != nil { - return inRangeCost, nil // 出错时返回范围内成本 + return inRangeCost, fmt.Errorf("out-range cost: %w", err) } // 合并成本 @@ -373,6 +406,14 @@ type ImagePriceConfig struct { Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值) } +// SoraPriceConfig Sora 按次计费配置 +type SoraPriceConfig struct { + ImagePrice360 *float64 + ImagePrice540 *float64 + VideoPricePerRequest *float64 + VideoPricePerRequestHD *float64 +} + // CalculateImageCost 计算图片生成费用 // model: 请求的模型名称(用于获取 LiteLLM 默认价格) // imageSize: 图片尺寸 "1K", "2K", "4K" @@ -402,6 +443,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag } } +// CalculateSoraImageCost 计算 Sora 图片按次费用 +func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { + if imageCount <= 0 { + return &CostBreakdown{} + } + + unitPrice := 0.0 + if groupConfig != nil { + switch imageSize { + case "540": + if groupConfig.ImagePrice540 != nil { + unitPrice = *groupConfig.ImagePrice540 + } + default: + if groupConfig.ImagePrice360 != nil { + unitPrice = *groupConfig.ImagePrice360 + } + } + } + + totalCost := unitPrice * float64(imageCount) + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + actualCost := totalCost * rateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + } +} + +// CalculateSoraVideoCost 计算 Sora 视频按次费用 +func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { + unitPrice := 0.0 + if groupConfig != nil { + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "sora2pro-hd") { + if groupConfig.VideoPricePerRequestHD != nil { + unitPrice = *groupConfig.VideoPricePerRequestHD + } + } + if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil { + unitPrice = *groupConfig.VideoPricePerRequest + } + } + + totalCost := unitPrice + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + actualCost := totalCost * rateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + } +} + // getImageUnitPrice 获取图片单价 func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 { // 优先使用分组配置的价格 @@ -443,7 +543,10 @@ func (s *BillingService) getDefaultImagePrice(model string, imageSize string) fl basePrice = 0.134 } - // 4K 尺寸翻倍 + // 2K 尺寸 1.5 倍,4K 尺寸翻倍 + if imageSize == "2K" { + return basePrice * 1.5 + } if imageSize == "4K" { return basePrice * 2 } diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go index 18a6b74d..fa90f6bb 100644 --- a/backend/internal/service/billing_service_image_test.go +++ b/backend/internal/service/billing_service_image_test.go @@ -12,14 +12,14 @@ import ( func TestCalculateImageCost_DefaultPricing(t *testing.T) { svc := &BillingService{} // pricingService 为 nil,使用硬编码默认值 - // 2K 尺寸,默认价格 $0.134 + // 2K 尺寸,默认价格 $0.134 * 1.5 = $0.201 cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) - require.InDelta(t, 0.134, cost.ActualCost, 0.0001) + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 多张图片 cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 3, nil, 1.0) - require.InDelta(t, 0.402, cost.TotalCost, 0.0001) + require.InDelta(t, 0.603, cost.TotalCost, 0.0001) } // TestCalculateImageCost_GroupCustomPricing 测试分组自定义价格 @@ -63,13 +63,13 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) { // 费率倍数 1.5x cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) // TotalCost 不变 - require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // ActualCost = 0.134 * 1.5 + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5 + require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5 // 费率倍数 2.0x cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 2, nil, 2.0) - require.InDelta(t, 0.268, cost.TotalCost, 0.0001) - require.InDelta(t, 0.536, cost.ActualCost, 0.0001) + require.InDelta(t, 0.402, cost.TotalCost, 0.0001) + require.InDelta(t, 0.804, cost.ActualCost, 0.0001) } // TestCalculateImageCost_ZeroCount 测试 imageCount=0 @@ -95,8 +95,8 @@ func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) { svc := &BillingService{} cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) - require.InDelta(t, 0.134, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 } // TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格 @@ -127,9 +127,9 @@ func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) { cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, groupConfig, 1.0) require.InDelta(t, 0.10, cost.TotalCost, 0.0001) - // 2K 回退默认价格 $0.134 + // 2K 回退默认价格 $0.201 (1.5倍) cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // 4K 回退默认价格 $0.268 (翻倍) cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0) @@ -140,10 +140,10 @@ func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) { func TestGetDefaultImagePrice_FallbackHardcoded(t *testing.T) { svc := &BillingService{} // pricingService 为 nil - // 1K 和 2K 使用相同的默认价格 $0.134 + // 1K 默认价格 $0.134,2K 默认价格 $0.201 (1.5倍) cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, nil, 1.0) require.InDelta(t, 0.134, cost.TotalCost, 0.0001) cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) } diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go new file mode 100644 index 00000000..5eb278f6 --- /dev/null +++ b/backend/internal/service/billing_service_test.go @@ -0,0 +1,437 @@ +//go:build unit + +package service + +import ( + "math" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func newTestBillingService() *BillingService { + return NewBillingService(&config.Config{}, nil) +} + +func TestCalculateCost_BasicComputation(t *testing.T) { + svc := newTestBillingService() + + // 使用 claude-sonnet-4 的回退价格:Input $3/MTok, Output $15/MTok + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + // 1000 * 3e-6 = 0.003, 500 * 15e-6 = 0.0075 + expectedInput := 1000 * 3e-6 + expectedOutput := 500 * 15e-6 + require.InDelta(t, expectedInput, cost.InputCost, 1e-10) + require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10) +} + +func TestCalculateCost_WithCacheTokens(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + CacheCreationTokens: 2000, + CacheReadTokens: 3000, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + expectedCacheCreation := 2000 * 3.75e-6 + expectedCacheRead := 3000 * 0.3e-6 + require.InDelta(t, expectedCacheCreation, cost.CacheCreationCost, 1e-10) + require.InDelta(t, expectedCacheRead, cost.CacheReadCost, 1e-10) + + expectedTotal := cost.InputCost + cost.OutputCost + expectedCacheCreation + expectedCacheRead + require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10) +} + +func TestCalculateCost_RateMultiplier(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + cost1x, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + cost2x, err := svc.CalculateCost("claude-sonnet-4", tokens, 2.0) + require.NoError(t, err) + + // TotalCost 不受倍率影响,ActualCost 翻倍 + require.InDelta(t, cost1x.TotalCost, cost2x.TotalCost, 1e-10) + require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) +} + +func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000} + + costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0) + require.NoError(t, err) + + costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) +} + +func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000} + + costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0) + require.NoError(t, err) + + costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) +} + +func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { + svc := newTestBillingService() + + tests := []struct { + model string + expectedInput float64 + }{ + {"claude-opus-4.5-20250101", 5e-6}, + {"claude-3-opus-20240229", 15e-6}, + {"claude-sonnet-4-20250514", 3e-6}, + {"claude-3-5-sonnet-20241022", 3e-6}, + {"claude-3-5-haiku-20241022", 1e-6}, + {"claude-3-haiku-20240307", 0.25e-6}, + } + + for _, tt := range tests { + pricing, err := svc.GetModelPricing(tt.model) + require.NoError(t, err, "模型 %s", tt.model) + require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12, "模型 %s 输入价格", tt.model) + } +} + +func TestGetModelPricing_CaseInsensitive(t *testing.T) { + svc := newTestBillingService() + + p1, err := svc.GetModelPricing("Claude-Sonnet-4") + require.NoError(t, err) + + p2, err := svc.GetModelPricing("claude-sonnet-4") + require.NoError(t, err) + + require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken) +} + +func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) { + svc := newTestBillingService() + + // 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格 + pricing, err := svc.GetModelPricing("claude-unknown-model") + require.NoError(t, err) + require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) +} + +func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 50000, + OutputTokens: 1000, + CacheReadTokens: 100000, + } + // 总输入 150k < 200k 阈值,应走正常计费 + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_AboveThreshold_CacheExceedsThreshold(t *testing.T) { + svc := newTestBillingService() + + // 缓存 210k + 输入 10k = 220k > 200k 阈值 + // 缓存已超阈值:范围内 200k 缓存,范围外 10k 缓存 + 10k 输入 + tokens := UsageTokens{ + InputTokens: 10000, + OutputTokens: 1000, + CacheReadTokens: 210000, + } + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + // 范围内:200k cache + 0 input + 1k output + inRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{ + InputTokens: 0, + OutputTokens: 1000, + CacheReadTokens: 200000, + }, 1.0) + + // 范围外:10k cache + 10k input,倍率 2.0 + outRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{ + InputTokens: 10000, + CacheReadTokens: 10000, + }, 2.0) + + require.InDelta(t, inRange.ActualCost+outRange.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_AboveThreshold_CacheBelowThreshold(t *testing.T) { + svc := newTestBillingService() + + // 缓存 100k + 输入 150k = 250k > 200k 阈值 + // 缓存未超阈值:范围内 100k 缓存 + 100k 输入,范围外 50k 输入 + tokens := UsageTokens{ + InputTokens: 150000, + OutputTokens: 1000, + CacheReadTokens: 100000, + } + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + require.True(t, cost.ActualCost > 0, "费用应大于 0") + + // 正常费用不含长上下文 + normalCost, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.True(t, cost.ActualCost > normalCost.ActualCost, "长上下文费用应高于正常费用") +} + +func TestCalculateCostWithLongContext_DisabledThreshold(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0} + + // threshold <= 0 应禁用长上下文计费 + cost1, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 0, 2.0) + require.NoError(t, err) + + cost2, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, cost2.ActualCost, cost1.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_ExtraMultiplierLessEqualOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 300000} + + // extraMultiplier <= 1 应禁用长上下文计费 + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 1.0) + require.NoError(t, err) + + normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateImageCost(t *testing.T) { + svc := newTestBillingService() + + price := 0.134 + cfg := &ImagePriceConfig{Price1K: &price} + cost := svc.CalculateImageCost("gpt-image-1", "1K", 3, cfg, 1.0) + + require.InDelta(t, 0.134*3, cost.TotalCost, 1e-10) + require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10) +} + +func TestCalculateSoraVideoCost(t *testing.T) { + svc := newTestBillingService() + + price := 0.5 + cfg := &SoraPriceConfig{VideoPricePerRequest: &price} + cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0) + + require.InDelta(t, 0.5, cost.TotalCost, 1e-10) +} + +func TestCalculateSoraVideoCost_HDModel(t *testing.T) { + svc := newTestBillingService() + + hdPrice := 1.0 + normalPrice := 0.5 + cfg := &SoraPriceConfig{ + VideoPricePerRequest: &normalPrice, + VideoPricePerRequestHD: &hdPrice, + } + cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0) + require.InDelta(t, 1.0, cost.TotalCost, 1e-10) +} + +func TestIsModelSupported(t *testing.T) { + svc := newTestBillingService() + + require.True(t, svc.IsModelSupported("claude-sonnet-4")) + require.True(t, svc.IsModelSupported("Claude-Opus-4.5")) + require.True(t, svc.IsModelSupported("claude-3-haiku")) + require.False(t, svc.IsModelSupported("gpt-4o")) + require.False(t, svc.IsModelSupported("gemini-pro")) +} + +func TestCalculateCost_ZeroTokens(t *testing.T) { + svc := newTestBillingService() + + cost, err := svc.CalculateCost("claude-sonnet-4", UsageTokens{}, 1.0) + require.NoError(t, err) + require.Equal(t, 0.0, cost.TotalCost) + require.Equal(t, 0.0, cost.ActualCost) +} + +func TestCalculateCostWithConfig(t *testing.T) { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 1.5 + svc := NewBillingService(cfg, nil) + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + cost, err := svc.CalculateCostWithConfig("claude-sonnet-4", tokens) + require.NoError(t, err) + + expected, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.5) + require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithConfig_ZeroMultiplier(t *testing.T) { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 0 + svc := NewBillingService(cfg, nil) + + tokens := UsageTokens{InputTokens: 1000} + cost, err := svc.CalculateCostWithConfig("claude-sonnet-4", tokens) + require.NoError(t, err) + + // 倍率 <=0 时默认 1.0 + expected, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10) +} + +func TestGetEstimatedCost(t *testing.T) { + svc := newTestBillingService() + + est, err := svc.GetEstimatedCost("claude-sonnet-4", 1000, 500) + require.NoError(t, err) + require.True(t, est > 0) +} + +func TestListSupportedModels(t *testing.T) { + svc := newTestBillingService() + + models := svc.ListSupportedModels() + require.NotEmpty(t, models) + require.GreaterOrEqual(t, len(models), 6) +} + +func TestGetPricingServiceStatus_NilService(t *testing.T) { + svc := newTestBillingService() + + status := svc.GetPricingServiceStatus() + require.NotNil(t, status) + require.Equal(t, "using fallback", status["last_updated"]) +} + +func TestForceUpdatePricing_NilService(t *testing.T) { + svc := newTestBillingService() + + err := svc.ForceUpdatePricing() + require.Error(t, err) + require.Contains(t, err.Error(), "not initialized") +} + +func TestCalculateSoraImageCost(t *testing.T) { + svc := newTestBillingService() + + price360 := 0.05 + price540 := 0.08 + cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540} + + cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0) + require.InDelta(t, 0.10, cost.TotalCost, 1e-10) + + cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0) + require.InDelta(t, 0.08, cost540.TotalCost, 1e-10) + require.InDelta(t, 0.16, cost540.ActualCost, 1e-10) +} + +func TestCalculateSoraImageCost_ZeroCount(t *testing.T) { + svc := newTestBillingService() + cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0) + require.Equal(t, 0.0, cost.TotalCost) +} + +func TestCalculateSoraVideoCost_NilConfig(t *testing.T) { + svc := newTestBillingService() + cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0) + require.Equal(t, 0.0, cost.TotalCost) +} + +func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) { + // 使用空的 fallback prices 让 GetModelPricing 失败 + svc := &BillingService{ + cfg: &config.Config{}, + fallbackPrices: make(map[string]*ModelPricing), + } + + tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0} + _, err := svc.CalculateCostWithLongContext("unknown-model", tokens, 1.0, 200000, 2.0) + require.Error(t, err) + require.Contains(t, err.Error(), "pricing not found") +} + +func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) { + svc := &BillingService{ + cfg: &config.Config{}, + fallbackPrices: map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 3e-6, + OutputPricePerToken: 15e-6, + SupportsCacheBreakdown: true, + CacheCreation5mPrice: 4e-6, // per token + CacheCreation1hPrice: 5e-6, // per token + }, + }, + } + + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + CacheCreation5mTokens: 100000, + CacheCreation1hTokens: 50000, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + expected5m := float64(tokens.CacheCreation5mTokens) * 4e-6 + expected1h := float64(tokens.CacheCreation1hTokens) * 5e-6 + require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10) +} + +func TestCalculateCost_LargeTokenCount(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 1_000_000, + OutputTokens: 1_000_000, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + // Input: 1M * 3e-6 = $3, Output: 1M * 15e-6 = $15 + require.InDelta(t, 3.0, cost.InputCost, 1e-6) + require.InDelta(t, 15.0, cost.OutputCost, 1e-6) + require.False(t, math.IsNaN(cost.TotalCost)) + require.False(t, math.IsInf(cost.TotalCost, 0)) +} diff --git a/backend/internal/service/claude_code_detection_test.go b/backend/internal/service/claude_code_detection_test.go new file mode 100644 index 00000000..ff7ad7f4 --- /dev/null +++ b/backend/internal/service/claude_code_detection_test.go @@ -0,0 +1,282 @@ +//go:build unit + +package service + +import ( + "context" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func newTestValidator() *ClaudeCodeValidator { + return NewClaudeCodeValidator() +} + +// validClaudeCodeBody 构造一个完整有效的 Claude Code 请求体 +func validClaudeCodeBody() map[string]any { + return map[string]any{ + "model": "claude-sonnet-4-20250514", + "system": []any{ + map[string]any{ + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + }, + }, + "metadata": map[string]any{ + "user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_" + "12345678-1234-1234-1234-123456789abc", + }, + } +} + +func TestValidate_ClaudeCLIUserAgent(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + ua string + want bool + }{ + {"标准版本号", "claude-cli/1.0.0", true}, + {"多位版本号", "claude-cli/12.34.56", true}, + {"大写开头", "Claude-CLI/1.0.0", true}, + {"非 claude-cli", "curl/7.64.1", false}, + {"空 User-Agent", "", false}, + {"部分匹配", "not-claude-cli/1.0.0", false}, + {"缺少版本号", "claude-cli/", false}, + {"版本格式不对", "claude-cli/1.0", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, v.ValidateUserAgent(tt.ua), "UA: %q", tt.ua) + }) + } +} + +func TestValidate_NonMessagesPath_UAOnly(t *testing.T) { + v := newTestValidator() + + // 非 messages 路径只检查 UA + req := httptest.NewRequest("GET", "/v1/models", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + + result := v.Validate(req, nil) + require.True(t, result, "非 messages 路径只需 UA 匹配") +} + +func TestValidate_NonMessagesPath_InvalidUA(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("GET", "/v1/models", nil) + req.Header.Set("User-Agent", "curl/7.64.1") + + result := v.Validate(req, nil) + require.False(t, result, "UA 不匹配时应返回 false") +} + +func TestValidate_MessagesPath_FullValid(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15") + req.Header.Set("anthropic-version", "2023-06-01") + + result := v.Validate(req, validClaudeCodeBody()) + require.True(t, result, "完整有效请求应通过") +} + +func TestValidate_MessagesPath_MissingHeaders(t *testing.T) { + v := newTestValidator() + body := validClaudeCodeBody() + + tests := []struct { + name string + missingHeader string + }{ + {"缺少 X-App", "X-App"}, + {"缺少 anthropic-beta", "anthropic-beta"}, + {"缺少 anthropic-version", "anthropic-version"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Del(tt.missingHeader) + + result := v.Validate(req, body) + require.False(t, result, "缺少 %s 应返回 false", tt.missingHeader) + }) + } +} + +func TestValidate_MessagesPath_InvalidMetadataUserID(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + metadata map[string]any + }{ + {"缺少 metadata", nil}, + {"缺少 user_id", map[string]any{"other": "value"}}, + {"空 user_id", map[string]any{"user_id": ""}}, + {"格式错误", map[string]any{"user_id": "invalid-format"}}, + {"hex 长度不足", map[string]any{"user_id": "user_abc_account__session_uuid"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{ + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + }, + }, + } + if tt.metadata != nil { + body["metadata"] = tt.metadata + } + + result := v.Validate(req, body) + require.False(t, result, "metadata.user_id: %v", tt.metadata) + }) + } +} + +func TestValidate_MessagesPath_InvalidSystemPrompt(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{ + "type": "text", + "text": "Generate JSON data for testing database migrations.", + }, + }, + "metadata": map[string]any{ + "user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_12345678-1234-1234-1234-123456789abc", + }, + } + + result := v.Validate(req, body) + require.False(t, result, "无关系统提示词应返回 false") +} + +func TestValidate_MaxTokensOneHaikuBypass(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + // 不设置 X-App 等头,通过 context 标记为 haiku 探测请求 + ctx := context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true) + req = req.WithContext(ctx) + + // 即使 body 不包含 system prompt,也应通过 + result := v.Validate(req, map[string]any{"model": "claude-3-haiku", "max_tokens": 1}) + require.True(t, result, "max_tokens=1+haiku 探测请求应绕过严格验证") +} + +func TestSystemPromptSimilarity(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + prompt string + want bool + }{ + {"精确匹配", "You are Claude Code, Anthropic's official CLI for Claude.", true}, + {"带多余空格", "You are Claude Code, Anthropic's official CLI for Claude.", true}, + {"Agent SDK 模板", "You are a Claude agent, built on Anthropic's Claude Agent SDK.", true}, + {"文件搜索专家模板", "You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.", true}, + {"对话摘要模板", "You are a helpful AI assistant tasked with summarizing conversations.", true}, + {"交互式 CLI 模板", "You are an interactive CLI tool that helps users", true}, + {"无关文本", "Write me a poem about cats", false}, + {"空文本", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{"type": "text", "text": tt.prompt}, + }, + } + result := v.IncludesClaudeCodeSystemPrompt(body) + require.Equal(t, tt.want, result, "提示词: %q", tt.prompt) + }) + } +} + +func TestDiceCoefficient(t *testing.T) { + tests := []struct { + name string + a string + b string + want float64 + tol float64 + }{ + {"相同字符串", "hello", "hello", 1.0, 0.001}, + {"完全不同", "abc", "xyz", 0.0, 0.001}, + {"空字符串", "", "hello", 0.0, 0.001}, + {"单字符", "a", "b", 0.0, 0.001}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := diceCoefficient(tt.a, tt.b) + require.InDelta(t, tt.want, result, tt.tol) + }) + } +} + +func TestIsClaudeCodeClient_Context(t *testing.T) { + ctx := context.Background() + + // 默认应为 false + require.False(t, IsClaudeCodeClient(ctx)) + + // 设置为 true + ctx = SetClaudeCodeClient(ctx, true) + require.True(t, IsClaudeCodeClient(ctx)) + + // 设置为 false + ctx = SetClaudeCodeClient(ctx, false) + require.False(t, IsClaudeCodeClient(ctx)) +} + +func TestValidate_NilBody_MessagesPath(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + result := v.Validate(req, nil) + require.False(t, result, "nil body 的 messages 请求应返回 false") +} diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go index 6d06c83e..f71098b1 100644 --- a/backend/internal/service/claude_code_validator.go +++ b/backend/internal/service/claude_code_validator.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "regexp" + "strconv" "strings" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" @@ -17,6 +18,9 @@ var ( // User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感) claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`) + // 带捕获组的版本提取正则 + claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`) + // metadata.user_id 格式: user_{64位hex}_account__session_{uuid} userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`) @@ -78,7 +82,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo // Step 3: 检查 max_tokens=1 + haiku 探测请求绕过 // 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt - if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku { + if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku { return true // 绕过 system prompt 检查,UA 已在 Step 1 验证 } @@ -270,3 +274,55 @@ func IsClaudeCodeClient(ctx context.Context) bool { func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context { return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode) } + +// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号 +// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串 +func (v *ClaudeCodeValidator) ExtractVersion(ua string) string { + matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) + if len(matches) >= 2 { + return matches[1] + } + return "" +} + +// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中 +func SetClaudeCodeVersion(ctx context.Context, version string) context.Context { + return context.WithValue(ctx, ctxkey.ClaudeCodeVersion, version) +} + +// GetClaudeCodeVersion 从 context 中获取 Claude Code 版本号 +func GetClaudeCodeVersion(ctx context.Context) string { + if v, ok := ctx.Value(ctxkey.ClaudeCodeVersion).(string); ok { + return v + } + return "" +} + +// CompareVersions 比较两个 semver 版本号 +// 返回: -1 (a < b), 0 (a == b), 1 (a > b) +func CompareVersions(a, b string) int { + aParts := parseSemver(a) + bParts := parseSemver(b) + for i := 0; i < 3; i++ { + if aParts[i] < bParts[i] { + return -1 + } + if aParts[i] > bParts[i] { + return 1 + } + } + return 0 +} + +// parseSemver 解析 semver 版本号为 [major, minor, patch] +func parseSemver(v string) [3]int { + v = strings.TrimPrefix(v, "v") + parts := strings.Split(v, ".") + result := [3]int{0, 0, 0} + for i := 0; i < len(parts) && i < 3; i++ { + if parsed, err := strconv.Atoi(parts[i]); err == nil { + result[i] = parsed + } + } + return result +} diff --git a/backend/internal/service/claude_code_validator_test.go b/backend/internal/service/claude_code_validator_test.go index a4cd1886..f87c56e8 100644 --- a/backend/internal/service/claude_code_validator_test.go +++ b/backend/internal/service/claude_code_validator_test.go @@ -56,3 +56,51 @@ func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) { ok := validator.Validate(req, nil) require.True(t, ok) } + +func TestExtractVersion(t *testing.T) { + v := NewClaudeCodeValidator() + tests := []struct { + ua string + want string + }{ + {"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"}, + {"claude-cli/1.0.0", "1.0.0"}, + {"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感 + {"curl/8.0.0", ""}, // 非 Claude CLI + {"", ""}, // 空字符串 + {"claude-cli/", ""}, // 无版本号 + {"claude-cli/2.1.22-beta", "2.1.22"}, // 带后缀仍提取主版本号 + } + for _, tt := range tests { + got := v.ExtractVersion(tt.ua) + require.Equal(t, tt.want, got, "ExtractVersion(%q)", tt.ua) + } +} + +func TestCompareVersions(t *testing.T) { + tests := []struct { + a, b string + want int + }{ + {"2.1.0", "2.1.0", 0}, // 相等 + {"2.1.1", "2.1.0", 1}, // patch 更大 + {"2.0.0", "2.1.0", -1}, // minor 更小 + {"3.0.0", "2.99.99", 1}, // major 更大 + {"1.0.0", "2.0.0", -1}, // major 更小 + {"0.0.1", "0.0.0", 1}, // patch 差异 + {"", "1.0.0", -1}, // 空字符串 vs 正常版本 + {"v2.1.0", "2.1.0", 0}, // v 前缀处理 + } + for _, tt := range tests { + got := CompareVersions(tt.a, tt.b) + require.Equal(t, tt.want, got, "CompareVersions(%q, %q)", tt.a, tt.b) + } +} + +func TestSetGetClaudeCodeVersion(t *testing.T) { + ctx := context.Background() + require.Equal(t, "", GetClaudeCodeVersion(ctx), "empty context should return empty string") + + ctx = SetClaudeCodeVersion(ctx, "2.1.63") + require.Equal(t, "2.1.63", GetClaudeCodeVersion(ctx)) +} diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index d5cb2025..4dcf84e0 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -3,10 +3,13 @@ package service import ( "context" "crypto/rand" - "encoding/hex" - "fmt" - "log" + "encoding/binary" + "os" + "strconv" + "sync/atomic" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // ConcurrencyCache 定义并发控制的缓存接口 @@ -17,6 +20,7 @@ type ConcurrencyCache interface { AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) + GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) // 账号等待队列(账号级) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) @@ -41,15 +45,25 @@ type ConcurrencyCache interface { CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error } -// generateRequestID generates a unique request ID for concurrency slot tracking -// Uses 8 random bytes (16 hex chars) for uniqueness -func generateRequestID() string { +var ( + requestIDPrefix = initRequestIDPrefix() + requestIDCounter atomic.Uint64 +) + +func initRequestIDPrefix() string { b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - // Fallback to nanosecond timestamp (extremely rare case) - return fmt.Sprintf("%x", time.Now().UnixNano()) + if _, err := rand.Read(b); err == nil { + return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36) } - return hex.EncodeToString(b) + fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16) + return "r" + strconv.FormatUint(fallback, 36) +} + +// generateRequestID generates a unique request ID for concurrency slot tracking. +// Format: {process_random_prefix}-{base36_counter} +func generateRequestID() string { + seq := requestIDCounter.Add(1) + return requestIDPrefix + "-" + strconv.FormatUint(seq, 36) } const ( @@ -124,7 +138,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil { - log.Printf("Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err) + logger.LegacyPrintf("service.concurrency", "Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err) } }, }, nil @@ -163,7 +177,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil { - log.Printf("Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err) + logger.LegacyPrintf("service.concurrency", "Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err) } }, }, nil @@ -191,7 +205,7 @@ func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int6 result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait) if err != nil { // On error, allow the request to proceed (fail open) - log.Printf("Warning: increment wait count failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for user %d: %v", userID, err) return true, nil } return result, nil @@ -209,7 +223,7 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6 defer cancel() if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil { - log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for user %d: %v", userID, err) } } @@ -221,7 +235,7 @@ func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, acco result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait) if err != nil { - log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err) + logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for account %d: %v", accountID, err) return true, nil } return result, nil @@ -237,7 +251,7 @@ func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, acco defer cancel() if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil { - log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err) + logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for account %d: %v", accountID, err) } } @@ -293,7 +307,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor accounts, err := accountRepo.ListSchedulable(listCtx) cancel() if err != nil { - log.Printf("Warning: list schedulable accounts failed: %v", err) + logger.LegacyPrintf("service.concurrency", "Warning: list schedulable accounts failed: %v", err) return } for _, account := range accounts { @@ -301,7 +315,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID) accountCancel() if err != nil { - log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.concurrency", "Warning: cleanup expired slots failed for account %d: %v", account.ID, err) } } } @@ -320,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // Returns a map of accountID -> current concurrency count func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { - result := make(map[int64]int) - - for _, accountID := range accountIDs { - count, err := s.cache.GetAccountConcurrency(ctx, accountID) - if err != nil { - // If key doesn't exist in Redis, count is 0 - count = 0 - } - result[accountID] = count + if len(accountIDs) == 0 { + return map[int64]int{}, nil } - - return result, nil + if s.cache == nil { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 + } + return result, nil + } + return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs) } diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go new file mode 100644 index 00000000..9ba43d93 --- /dev/null +++ b/backend/internal/service/concurrency_service_test.go @@ -0,0 +1,311 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩 +type stubConcurrencyCacheForTest struct { + acquireResult bool + acquireErr error + releaseErr error + concurrency int + concurrencyErr error + waitAllowed bool + waitErr error + waitCount int + waitCountErr error + loadBatch map[int64]*AccountLoadInfo + loadBatchErr error + usersLoadBatch map[int64]*UserLoadInfo + usersLoadErr error + cleanupErr error + + // 记录调用 + releasedAccountIDs []int64 + releasedRequestIDs []string +} + +var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil) + +func (c *stubConcurrencyCacheForTest) AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return c.acquireResult, c.acquireErr +} +func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, accountID int64, requestID string) error { + c.releasedAccountIDs = append(c.releasedAccountIDs, accountID) + c.releasedRequestIDs = append(c.releasedRequestIDs, requestID) + return c.releaseErr +} +func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) { + return c.concurrency, c.concurrencyErr +} +func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + if c.concurrencyErr != nil { + return nil, c.concurrencyErr + } + result[accountID] = c.concurrency + } + return result, nil +} +func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return c.waitAllowed, c.waitErr +} +func (c *stubConcurrencyCacheForTest) DecrementAccountWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c *stubConcurrencyCacheForTest) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) { + return c.waitCount, c.waitCountErr +} +func (c *stubConcurrencyCacheForTest) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return c.acquireResult, c.acquireErr +} +func (c *stubConcurrencyCacheForTest) ReleaseUserSlot(_ context.Context, _ int64, _ string) error { + return c.releaseErr +} +func (c *stubConcurrencyCacheForTest) GetUserConcurrency(_ context.Context, _ int64) (int, error) { + return c.concurrency, c.concurrencyErr +} +func (c *stubConcurrencyCacheForTest) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return c.waitAllowed, c.waitErr +} +func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + return c.loadBatch, c.loadBatchErr +} +func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { + return c.usersLoadBatch, c.usersLoadErr +} +func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { + return c.cleanupErr +} + +func TestAcquireAccountSlot_Success(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.NoError(t, err) + require.True(t, result.Acquired) + require.NotNil(t, result.ReleaseFunc) +} + +func TestAcquireAccountSlot_Failure(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: false} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.NoError(t, err) + require.False(t, result.Acquired) + require.Nil(t, result.ReleaseFunc) +} + +func TestAcquireAccountSlot_UnlimitedConcurrency(t *testing.T) { + svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) + + for _, maxConcurrency := range []int{0, -1} { + result, err := svc.AcquireAccountSlot(context.Background(), 1, maxConcurrency) + require.NoError(t, err) + require.True(t, result.Acquired, "maxConcurrency=%d 应无限制通过", maxConcurrency) + require.NotNil(t, result.ReleaseFunc, "ReleaseFunc 应为 no-op 函数") + } +} + +func TestAcquireAccountSlot_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireErr: errors.New("redis down")} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.Error(t, err) + require.Nil(t, result) +} + +func TestAcquireAccountSlot_ReleaseDecrements(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 42, 5) + require.NoError(t, err) + require.True(t, result.Acquired) + + // 调用 ReleaseFunc 应释放槽位 + result.ReleaseFunc() + + require.Len(t, cache.releasedAccountIDs, 1) + require.Equal(t, int64(42), cache.releasedAccountIDs[0]) + require.Len(t, cache.releasedRequestIDs, 1) + require.NotEmpty(t, cache.releasedRequestIDs[0], "requestID 不应为空") +} + +func TestAcquireUserSlot_IndependentFromAccount(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + // 用户槽位获取应独立于账户槽位 + result, err := svc.AcquireUserSlot(context.Background(), 100, 3) + require.NoError(t, err) + require.True(t, result.Acquired) + require.NotNil(t, result.ReleaseFunc) +} + +func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) { + svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) + + result, err := svc.AcquireUserSlot(context.Background(), 1, 0) + require.NoError(t, err) + require.True(t, result.Acquired) +} + +func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) { + id1 := generateRequestID() + id2 := generateRequestID() + require.NotEmpty(t, id1) + require.NotEmpty(t, id2) + + p1 := strings.Split(id1, "-") + p2 := strings.Split(id2, "-") + require.Len(t, p1, 2) + require.Len(t, p2, 2) + require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致") + + n1, err := strconv.ParseUint(p1[1], 36, 64) + require.NoError(t, err) + n2, err := strconv.ParseUint(p2[1], 36, 64) + require.NoError(t, err) + require.Equal(t, n1+1, n2, "计数器应单调递增") +} + +func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) { + expected := map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60}, + 2: {AccountID: 2, CurrentConcurrency: 5, WaitingCount: 2, LoadRate: 100}, + } + cache := &stubConcurrencyCacheForTest{loadBatch: expected} + svc := NewConcurrencyService(cache) + + accounts := []AccountWithConcurrency{ + {ID: 1, MaxConcurrency: 5}, + {ID: 2, MaxConcurrency: 5}, + } + result, err := svc.GetAccountsLoadBatch(context.Background(), accounts) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func TestGetAccountsLoadBatch_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + result, err := svc.GetAccountsLoadBatch(context.Background(), nil) + require.NoError(t, err) + require.Empty(t, result) +} + +func TestIncrementWaitCount_Success(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed) +} + +func TestIncrementWaitCount_QueueFull(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: false} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.False(t, allowed) +} + +func TestIncrementWaitCount_FailOpen(t *testing.T) { + // Redis 错误时应 fail-open(允许请求通过) + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis timeout")} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err, "Redis 错误不应传播") + require.True(t, allowed, "Redis 错误时应 fail-open") +} + +func TestIncrementWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed, "nil cache 应 fail-open") +} + +func TestCalculateMaxWait(t *testing.T) { + tests := []struct { + concurrency int + expected int + }{ + {5, 25}, // 5 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // min(1) + 20 + {10, 30}, // 10 + 20 + } + for _, tt := range tests { + result := CalculateMaxWait(tt.concurrency) + require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency) + } +} + +func TestGetAccountWaitingCount(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitCount: 5} + svc := NewConcurrencyService(cache) + + count, err := svc.GetAccountWaitingCount(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, 5, count) +} + +func TestGetAccountWaitingCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + count, err := svc.GetAccountWaitingCount(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, 0, count) +} + +func TestGetAccountConcurrencyBatch(t *testing.T) { + cache := &stubConcurrencyCacheForTest{concurrency: 3} + svc := NewConcurrencyService(cache) + + result, err := svc.GetAccountConcurrencyBatch(context.Background(), []int64{1, 2, 3}) + require.NoError(t, err) + require.Len(t, result, 3) + for _, id := range []int64{1, 2, 3} { + require.Equal(t, 3, result[id]) + } +} + +func TestIncrementAccountWaitCount_FailOpen(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis error")} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err, "Redis 错误不应传播") + require.True(t, allowed, "Redis 错误时应 fail-open") +} + +func TestIncrementAccountWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err) + require.True(t, allowed) +} diff --git a/backend/internal/service/crs_sync_helpers_test.go b/backend/internal/service/crs_sync_helpers_test.go new file mode 100644 index 00000000..0dc05335 --- /dev/null +++ b/backend/internal/service/crs_sync_helpers_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "testing" +) + +func TestBuildSelectedSet(t *testing.T) { + tests := []struct { + name string + ids []string + wantNil bool + wantSize int + }{ + { + name: "nil input returns nil (backward compatible: create all)", + ids: nil, + wantNil: true, + }, + { + name: "empty slice returns empty map (create none)", + ids: []string{}, + wantNil: false, + wantSize: 0, + }, + { + name: "single ID", + ids: []string{"abc-123"}, + wantNil: false, + wantSize: 1, + }, + { + name: "multiple IDs", + ids: []string{"a", "b", "c"}, + wantNil: false, + wantSize: 3, + }, + { + name: "duplicate IDs are deduplicated", + ids: []string{"a", "a", "b"}, + wantNil: false, + wantSize: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildSelectedSet(tt.ids) + if tt.wantNil { + if got != nil { + t.Errorf("buildSelectedSet(%v) = %v, want nil", tt.ids, got) + } + return + } + if got == nil { + t.Fatalf("buildSelectedSet(%v) = nil, want non-nil map", tt.ids) + } + if len(got) != tt.wantSize { + t.Errorf("buildSelectedSet(%v) has %d entries, want %d", tt.ids, len(got), tt.wantSize) + } + // Verify all unique IDs are present + for _, id := range tt.ids { + if _, ok := got[id]; !ok { + t.Errorf("buildSelectedSet(%v) missing key %q", tt.ids, id) + } + } + }) + } +} + +func TestShouldCreateAccount(t *testing.T) { + tests := []struct { + name string + crsID string + selectedSet map[string]struct{} + want bool + }{ + { + name: "nil set allows all (backward compatible)", + crsID: "any-id", + selectedSet: nil, + want: true, + }, + { + name: "empty set blocks all", + crsID: "any-id", + selectedSet: map[string]struct{}{}, + want: false, + }, + { + name: "ID in set is allowed", + crsID: "abc-123", + selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}}, + want: true, + }, + { + name: "ID not in set is blocked", + crsID: "xyz-789", + selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shouldCreateAccount(tt.crsID, tt.selectedSet) + if got != tt.want { + t.Errorf("shouldCreateAccount(%q, %v) = %v, want %v", + tt.crsID, tt.selectedSet, got, tt.want) + } + }) + } +} diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index a6ccb967..6a916740 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -45,10 +45,11 @@ func NewCRSSyncService( } type SyncFromCRSInput struct { - BaseURL string - Username string - Password string - SyncProxies bool + BaseURL string + Username string + Password string + SyncProxies bool + SelectedAccountIDs []string // if non-empty, only create new accounts with these CRS IDs } type SyncFromCRSItemResult struct { @@ -190,25 +191,27 @@ type crsGeminiAPIKeyAccount struct { Extra map[string]any `json:"extra"` } -func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) { +// fetchCRSExport validates the connection parameters, authenticates with CRS, +// and returns the exported accounts. Shared by SyncFromCRS and PreviewFromCRS. +func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, password string) (*crsExportResponse, error) { if s.cfg == nil { return nil, errors.New("config is not available") } - baseURL := strings.TrimSpace(input.BaseURL) + normalizedURL := strings.TrimSpace(baseURL) if s.cfg.Security.URLAllowlist.Enabled { - normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts) + normalized, err := normalizeBaseURL(normalizedURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts) if err != nil { return nil, err } - baseURL = normalized + normalizedURL = normalized } else { - normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) + normalized, err := urlvalidator.ValidateURLFormat(normalizedURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) if err != nil { return nil, fmt.Errorf("invalid base_url: %w", err) } - baseURL = normalized + normalizedURL = normalized } - if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" { + if strings.TrimSpace(username) == "" || strings.TrimSpace(password) == "" { return nil, errors.New("username and password are required") } @@ -218,15 +221,19 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts, }) if err != nil { - client = &http.Client{Timeout: 20 * time.Second} + return nil, fmt.Errorf("create http client failed: %w", err) } - adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password) + adminToken, err := crsLogin(ctx, client, normalizedURL, username, password) if err != nil { return nil, err } - exported, err := crsExportAccounts(ctx, client, baseURL, adminToken) + return crsExportAccounts(ctx, client, normalizedURL, adminToken) +} + +func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) { + exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password) if err != nil { return nil, err } @@ -241,6 +248,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ), } + selectedSet := buildSelectedSet(input.SelectedAccountIDs) + var proxies []Proxy if input.SyncProxies { proxies, _ = s.proxyRepo.ListActive(ctx) @@ -329,6 +338,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformAnthropic, @@ -446,6 +462,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformAnthropic, @@ -569,6 +592,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformOpenAI, @@ -690,6 +720,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformOpenAI, @@ -798,6 +835,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformGemini, @@ -909,6 +953,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformGemini, @@ -1253,3 +1304,102 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account return newCredentials } + +// buildSelectedSet converts a slice of selected CRS account IDs to a set for O(1) lookup. +// Returns nil if ids is nil (field not sent → backward compatible: create all). +// Returns an empty map if ids is non-nil but empty (user selected none → create none). +func buildSelectedSet(ids []string) map[string]struct{} { + if ids == nil { + return nil + } + set := make(map[string]struct{}, len(ids)) + for _, id := range ids { + set[id] = struct{}{} + } + return set +} + +// shouldCreateAccount checks if a new CRS account should be created based on user selection. +// Returns true if selectedSet is nil (backward compatible: create all) or if crsID is in the set. +func shouldCreateAccount(crsID string, selectedSet map[string]struct{}) bool { + if selectedSet == nil { + return true + } + _, ok := selectedSet[crsID] + return ok +} + +// PreviewFromCRSResult contains the preview of accounts from CRS before sync. +type PreviewFromCRSResult struct { + NewAccounts []CRSPreviewAccount `json:"new_accounts"` + ExistingAccounts []CRSPreviewAccount `json:"existing_accounts"` +} + +// CRSPreviewAccount represents a single account in the preview result. +type CRSPreviewAccount struct { + CRSAccountID string `json:"crs_account_id"` + Kind string `json:"kind"` + Name string `json:"name"` + Platform string `json:"platform"` + Type string `json:"type"` +} + +// PreviewFromCRS connects to CRS, fetches all accounts, and classifies them +// as new or existing by batch-querying local crs_account_id mappings. +func (s *CRSSyncService) PreviewFromCRS(ctx context.Context, input SyncFromCRSInput) (*PreviewFromCRSResult, error) { + exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password) + if err != nil { + return nil, err + } + + // Batch query all existing CRS account IDs + existingCRSIDs, err := s.accountRepo.ListCRSAccountIDs(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list existing CRS accounts: %w", err) + } + + result := &PreviewFromCRSResult{ + NewAccounts: make([]CRSPreviewAccount, 0), + ExistingAccounts: make([]CRSPreviewAccount, 0), + } + + classify := func(crsID, kind, name, platform, accountType string) { + preview := CRSPreviewAccount{ + CRSAccountID: crsID, + Kind: kind, + Name: defaultName(name, crsID), + Platform: platform, + Type: accountType, + } + if _, exists := existingCRSIDs[crsID]; exists { + result.ExistingAccounts = append(result.ExistingAccounts, preview) + } else { + result.NewAccounts = append(result.NewAccounts, preview) + } + } + + for _, src := range exported.Data.ClaudeAccounts { + authType := strings.TrimSpace(src.AuthType) + if authType == "" { + authType = AccountTypeOAuth + } + classify(src.ID, src.Kind, src.Name, PlatformAnthropic, authType) + } + for _, src := range exported.Data.ClaudeConsoleAccounts { + classify(src.ID, src.Kind, src.Name, PlatformAnthropic, AccountTypeAPIKey) + } + for _, src := range exported.Data.OpenAIOAuthAccounts { + classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeOAuth) + } + for _, src := range exported.Data.OpenAIResponsesAccounts { + classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeAPIKey) + } + for _, src := range exported.Data.GeminiOAuthAccounts { + classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeOAuth) + } + for _, src := range exported.Data.GeminiAPIKeyAccounts { + classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeAPIKey) + } + + return result, nil +} diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go index 10c68868..a67f8532 100644 --- a/backend/internal/service/dashboard_aggregation_service.go +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -3,11 +3,12 @@ package service import ( "context" "errors" - "log" + "log/slog" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) const ( @@ -65,7 +66,7 @@ func (s *DashboardAggregationService) Start() { return } if !s.cfg.Enabled { - log.Printf("[DashboardAggregation] 聚合作业已禁用") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业已禁用") return } @@ -81,9 +82,9 @@ func (s *DashboardAggregationService) Start() { s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() { s.runScheduledAggregation() }) - log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds) if !s.cfg.BackfillEnabled { - log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填") } } @@ -93,7 +94,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro return errors.New("聚合服务未初始化") } if !s.cfg.BackfillEnabled { - log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填被拒绝: backfill_enabled=false") return ErrDashboardBackfillDisabled } if !end.After(start) { @@ -110,7 +111,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) defer cancel() if err := s.backfillRange(ctx, start, end); err != nil { - log.Printf("[DashboardAggregation] 回填失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填失败: %v", err) } }() return nil @@ -141,12 +142,12 @@ func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time return } if !errors.Is(err, errDashboardAggregationRunning) { - log.Printf("[DashboardAggregation] 重新计算失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算失败: %v", err) return } time.Sleep(5 * time.Second) } - log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算放弃: 聚合作业持续占用") }() return nil } @@ -162,7 +163,7 @@ func (s *DashboardAggregationService) recomputeRecentDays() { ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) defer cancel() if err := s.backfillRange(ctx, start, now); err != nil { - log.Printf("[DashboardAggregation] 启动重算失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 启动重算失败: %v", err) return } } @@ -177,7 +178,7 @@ func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, if err := s.repo.RecomputeRange(ctx, start, end); err != nil { return err } - log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)", + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)", start.UTC().Format(time.RFC3339), end.UTC().Format(time.RFC3339), time.Since(jobStart).String(), @@ -198,7 +199,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() { now := time.Now().UTC() last, err := s.repo.GetAggregationWatermark(ctx) if err != nil { - log.Printf("[DashboardAggregation] 读取水位失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 读取水位失败: %v", err) last = time.Unix(0, 0).UTC() } @@ -216,19 +217,19 @@ func (s *DashboardAggregationService) runScheduledAggregation() { } if err := s.aggregateRange(ctx, start, now); err != nil { - log.Printf("[DashboardAggregation] 聚合失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合失败: %v", err) return } updateErr := s.repo.UpdateAggregationWatermark(ctx, now) if updateErr != nil { - log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr) } - log.Printf("[DashboardAggregation] 聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)", - start.Format(time.RFC3339), - now.Format(time.RFC3339), - time.Since(jobStart).String(), - updateErr == nil, + slog.Debug("[DashboardAggregation] 聚合完成", + "start", start.Format(time.RFC3339), + "end", now.Format(time.RFC3339), + "duration", time.Since(jobStart).String(), + "watermark_updated", updateErr == nil, ) s.maybeCleanupRetention(ctx, now) @@ -261,9 +262,9 @@ func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC) if updateErr != nil { - log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr) } - log.Printf("[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)", + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)", startUTC.Format(time.RFC3339), endUTC.Format(time.RFC3339), time.Since(jobStart).String(), @@ -279,7 +280,7 @@ func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, return nil } if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil { - log.Printf("[DashboardAggregation] 分区检查失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 分区检查失败: %v", err) } return s.repo.AggregateRange(ctx, start, end) } @@ -298,11 +299,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff) if aggErr != nil { - log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合保留清理失败: %v", aggErr) } usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff) if usageErr != nil { - log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr) } if aggErr == nil && usageErr == nil { s.lastRetentionCleanup.Store(now) diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index cd11923e..2af43386 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -5,11 +5,11 @@ import ( "encoding/json" "errors" "fmt" - "log" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" ) @@ -113,7 +113,7 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D return cached, nil } if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) { - log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存读取失败: %v", err) } } @@ -124,22 +124,30 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D return stats, nil } -func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { - trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) +func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { + trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) if err != nil { return nil, fmt.Errorf("get usage trend with filters: %w", err) } return trend, nil } -func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) +func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { return nil, fmt.Errorf("get model stats with filters: %w", err) } return stats, nil } +func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { + stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + if err != nil { + return nil, fmt.Errorf("get group stats with filters: %w", err) + } + return stats, nil +} + func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) { data, err := s.cache.GetDashboardStats(ctx) if err != nil { @@ -188,7 +196,7 @@ func (s *DashboardService) refreshDashboardStatsAsync() { stats, err := s.fetchDashboardStats(ctx) if err != nil { - log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异步刷新失败: %v", err) return } s.applyAggregationStatus(ctx, stats) @@ -220,12 +228,12 @@ func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *u } data, err := json.Marshal(entry) if err != nil { - log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存序列化失败: %v", err) return } if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil { - log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存写入失败: %v", err) } } @@ -237,10 +245,10 @@ func (s *DashboardService) evictDashboardStatsCache(reason error) { defer cancel() if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil { - log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存清理失败: %v", err) } if reason != nil { - log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异常,已清理: %v", reason) } } @@ -271,7 +279,7 @@ func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.T } updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx) if err != nil { - log.Printf("[Dashboard] 读取聚合水位失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 读取聚合水位失败: %v", err) return time.Unix(0, 0).UTC() } if updatedAt.IsZero() { @@ -319,16 +327,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end return trend, nil } -func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { - stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs) +func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { + stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch user usage stats: %w", err) } return stats, nil } -func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { - stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs) +func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { + stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch api key usage stats: %w", err) } diff --git a/backend/internal/service/data_management_grpc.go b/backend/internal/service/data_management_grpc.go new file mode 100644 index 00000000..aeb3d529 --- /dev/null +++ b/backend/internal/service/data_management_grpc.go @@ -0,0 +1,252 @@ +package service + +import "context" + +type DataManagementPostgresConfig struct { + Host string `json:"host"` + Port int32 `json:"port"` + User string `json:"user"` + Password string `json:"password,omitempty"` + PasswordConfigured bool `json:"password_configured"` + Database string `json:"database"` + SSLMode string `json:"ssl_mode"` + ContainerName string `json:"container_name"` +} + +type DataManagementRedisConfig struct { + Addr string `json:"addr"` + Username string `json:"username"` + Password string `json:"password,omitempty"` + PasswordConfigured bool `json:"password_configured"` + DB int32 `json:"db"` + ContainerName string `json:"container_name"` +} + +type DataManagementS3Config struct { + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key,omitempty"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +type DataManagementConfig struct { + SourceMode string `json:"source_mode"` + BackupRoot string `json:"backup_root"` + SQLitePath string `json:"sqlite_path,omitempty"` + RetentionDays int32 `json:"retention_days"` + KeepLast int32 `json:"keep_last"` + ActivePostgresID string `json:"active_postgres_profile_id"` + ActiveRedisID string `json:"active_redis_profile_id"` + Postgres DataManagementPostgresConfig `json:"postgres"` + Redis DataManagementRedisConfig `json:"redis"` + S3 DataManagementS3Config `json:"s3"` + ActiveS3ProfileID string `json:"active_s3_profile_id"` +} + +type DataManagementTestS3Result struct { + OK bool `json:"ok"` + Message string `json:"message"` +} + +type DataManagementCreateBackupJobInput struct { + BackupType string + UploadToS3 bool + TriggeredBy string + IdempotencyKey string + S3ProfileID string + PostgresID string + RedisID string +} + +type DataManagementListBackupJobsInput struct { + PageSize int32 + PageToken string + Status string + BackupType string +} + +type DataManagementArtifactInfo struct { + LocalPath string `json:"local_path"` + SizeBytes int64 `json:"size_bytes"` + SHA256 string `json:"sha256"` +} + +type DataManagementS3ObjectInfo struct { + Bucket string `json:"bucket"` + Key string `json:"key"` + ETag string `json:"etag"` +} + +type DataManagementBackupJob struct { + JobID string `json:"job_id"` + BackupType string `json:"backup_type"` + Status string `json:"status"` + TriggeredBy string `json:"triggered_by"` + IdempotencyKey string `json:"idempotency_key,omitempty"` + UploadToS3 bool `json:"upload_to_s3"` + S3ProfileID string `json:"s3_profile_id,omitempty"` + PostgresID string `json:"postgres_profile_id,omitempty"` + RedisID string `json:"redis_profile_id,omitempty"` + StartedAt string `json:"started_at,omitempty"` + FinishedAt string `json:"finished_at,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + Artifact DataManagementArtifactInfo `json:"artifact"` + S3Object DataManagementS3ObjectInfo `json:"s3"` +} + +type DataManagementSourceProfile struct { + SourceType string `json:"source_type"` + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + Config DataManagementSourceConfig `json:"config"` + PasswordConfigured bool `json:"password_configured"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type DataManagementSourceConfig struct { + Host string `json:"host"` + Port int32 `json:"port"` + User string `json:"user"` + Password string `json:"password,omitempty"` + Database string `json:"database"` + SSLMode string `json:"ssl_mode"` + Addr string `json:"addr"` + Username string `json:"username"` + DB int32 `json:"db"` + ContainerName string `json:"container_name"` +} + +type DataManagementCreateSourceProfileInput struct { + SourceType string + ProfileID string + Name string + Config DataManagementSourceConfig + SetActive bool +} + +type DataManagementUpdateSourceProfileInput struct { + SourceType string + ProfileID string + Name string + Config DataManagementSourceConfig +} + +type DataManagementS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + S3 DataManagementS3Config `json:"s3"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type DataManagementCreateS3ProfileInput struct { + ProfileID string + Name string + S3 DataManagementS3Config + SetActive bool +} + +type DataManagementUpdateS3ProfileInput struct { + ProfileID string + Name string + S3 DataManagementS3Config +} + +type DataManagementListBackupJobsResult struct { + Items []DataManagementBackupJob `json:"items"` + NextPageToken string `json:"next_page_token,omitempty"` +} + +func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) { + _ = ctx + return DataManagementConfig{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) { + _, _ = ctx, cfg + return DataManagementConfig{}, s.deprecatedError() +} + +func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) { + _, _ = ctx, sourceType + return nil, s.deprecatedError() +} + +func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) { + _, _ = ctx, input + return DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) { + _, _ = ctx, input + return DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error { + _, _, _ = ctx, sourceType, profileID + return s.deprecatedError() +} + +func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) { + _, _, _ = ctx, sourceType, profileID + return DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) { + _, _ = ctx, cfg + return DataManagementTestS3Result{}, s.deprecatedError() +} + +func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) { + _ = ctx + return nil, s.deprecatedError() +} + +func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) { + _, _ = ctx, input + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) { + _, _ = ctx, input + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error { + _, _ = ctx, profileID + return s.deprecatedError() +} + +func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) { + _, _ = ctx, profileID + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) CreateBackupJob(ctx context.Context, input DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) { + _, _ = ctx, input + return DataManagementBackupJob{}, s.deprecatedError() +} + +func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) { + _, _ = ctx, input + return DataManagementListBackupJobsResult{}, s.deprecatedError() +} + +func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) { + _, _ = ctx, jobID + return DataManagementBackupJob{}, s.deprecatedError() +} + +func (s *DataManagementService) deprecatedError() error { + return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()}) +} diff --git a/backend/internal/service/data_management_grpc_test.go b/backend/internal/service/data_management_grpc_test.go new file mode 100644 index 00000000..286eb58d --- /dev/null +++ b/backend/internal/service/data_management_grpc_test.go @@ -0,0 +1,36 @@ +package service + +import ( + "context" + "path/filepath" + "testing" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestDataManagementService_DeprecatedRPCMethods(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "datamanagement.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 0) + + _, err := svc.GetConfig(context.Background()) + assertDeprecatedDataManagementError(t, err, socketPath) + + _, err = svc.CreateBackupJob(context.Background(), DataManagementCreateBackupJobInput{BackupType: "full"}) + assertDeprecatedDataManagementError(t, err, socketPath) + + err = svc.DeleteS3Profile(context.Background(), "s3-default") + assertDeprecatedDataManagementError(t, err, socketPath) +} + +func assertDeprecatedDataManagementError(t *testing.T, err error, socketPath string) { + t.Helper() + + require.Error(t, err) + statusCode, status := infraerrors.ToHTTP(err) + require.Equal(t, 503, statusCode) + require.Equal(t, DataManagementDeprecatedReason, status.Reason) + require.Equal(t, socketPath, status.Metadata["socket_path"]) +} diff --git a/backend/internal/service/data_management_service.go b/backend/internal/service/data_management_service.go new file mode 100644 index 00000000..b525c0fa --- /dev/null +++ b/backend/internal/service/data_management_service.go @@ -0,0 +1,95 @@ +package service + +import ( + "context" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +const ( + DefaultDataManagementAgentSocketPath = "/tmp/sub2api-datamanagement.sock" + LegacyBackupAgentSocketPath = "/tmp/sub2api-backup.sock" + + DataManagementDeprecatedReason = "DATA_MANAGEMENT_DEPRECATED" + DataManagementAgentSocketMissingReason = "DATA_MANAGEMENT_AGENT_SOCKET_MISSING" + DataManagementAgentUnavailableReason = "DATA_MANAGEMENT_AGENT_UNAVAILABLE" + + // Deprecated: keep old names for compatibility. + DefaultBackupAgentSocketPath = DefaultDataManagementAgentSocketPath + BackupAgentSocketMissingReason = DataManagementAgentSocketMissingReason + BackupAgentUnavailableReason = DataManagementAgentUnavailableReason +) + +var ( + ErrDataManagementDeprecated = infraerrors.ServiceUnavailable( + DataManagementDeprecatedReason, + "data management feature is deprecated", + ) + ErrDataManagementAgentSocketMissing = infraerrors.ServiceUnavailable( + DataManagementAgentSocketMissingReason, + "data management agent socket is missing", + ) + ErrDataManagementAgentUnavailable = infraerrors.ServiceUnavailable( + DataManagementAgentUnavailableReason, + "data management agent is unavailable", + ) + + // Deprecated: keep old names for compatibility. + ErrBackupAgentSocketMissing = ErrDataManagementAgentSocketMissing + ErrBackupAgentUnavailable = ErrDataManagementAgentUnavailable +) + +type DataManagementAgentHealth struct { + Enabled bool + Reason string + SocketPath string + Agent *DataManagementAgentInfo +} + +type DataManagementAgentInfo struct { + Status string + Version string + UptimeSeconds int64 +} + +type DataManagementService struct { + socketPath string +} + +func NewDataManagementService() *DataManagementService { + return NewDataManagementServiceWithOptions(DefaultDataManagementAgentSocketPath, 500*time.Millisecond) +} + +func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService { + _ = dialTimeout + path := strings.TrimSpace(socketPath) + if path == "" { + path = DefaultDataManagementAgentSocketPath + } + return &DataManagementService{ + socketPath: path, + } +} + +func (s *DataManagementService) SocketPath() string { + if s == nil || strings.TrimSpace(s.socketPath) == "" { + return DefaultDataManagementAgentSocketPath + } + return s.socketPath +} + +func (s *DataManagementService) GetAgentHealth(ctx context.Context) DataManagementAgentHealth { + _ = ctx + return DataManagementAgentHealth{ + Enabled: false, + Reason: DataManagementDeprecatedReason, + SocketPath: s.SocketPath(), + } +} + +func (s *DataManagementService) EnsureAgentEnabled(ctx context.Context) error { + _ = ctx + return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()}) +} diff --git a/backend/internal/service/data_management_service_test.go b/backend/internal/service/data_management_service_test.go new file mode 100644 index 00000000..65489d2e --- /dev/null +++ b/backend/internal/service/data_management_service_test.go @@ -0,0 +1,37 @@ +package service + +import ( + "context" + "path/filepath" + "testing" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestDataManagementService_GetAgentHealth_Deprecated(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "unused.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 0) + health := svc.GetAgentHealth(context.Background()) + + require.False(t, health.Enabled) + require.Equal(t, DataManagementDeprecatedReason, health.Reason) + require.Equal(t, socketPath, health.SocketPath) + require.Nil(t, health.Agent) +} + +func TestDataManagementService_EnsureAgentEnabled_Deprecated(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "unused.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 100) + err := svc.EnsureAgentEnabled(context.Background()) + require.Error(t, err) + + statusCode, status := infraerrors.ToHTTP(err) + require.Equal(t, 503, statusCode) + require.Equal(t, DataManagementDeprecatedReason, status.Reason) + require.Equal(t, socketPath, status.Metadata["socket_path"]) +} diff --git a/backend/internal/service/digest_session_store.go b/backend/internal/service/digest_session_store.go new file mode 100644 index 00000000..3ac08936 --- /dev/null +++ b/backend/internal/service/digest_session_store.go @@ -0,0 +1,69 @@ +package service + +import ( + "strconv" + "strings" + "time" + + gocache "github.com/patrickmn/go-cache" +) + +// digestSessionTTL 摘要会话默认 TTL +const digestSessionTTL = 5 * time.Minute + +// sessionEntry flat cache 条目 +type sessionEntry struct { + uuid string + accountID int64 +} + +// DigestSessionStore 内存摘要会话存储(flat cache 实现) +// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry +type DigestSessionStore struct { + cache *gocache.Cache +} + +// NewDigestSessionStore 创建内存摘要会话存储 +func NewDigestSessionStore() *DigestSessionStore { + return &DigestSessionStore{ + cache: gocache.New(digestSessionTTL, time.Minute), + } +} + +// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。 +func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) { + if digestChain == "" { + return + } + ns := buildNS(groupID, prefixHash) + s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration) + if oldDigestChain != "" && oldDigestChain != digestChain { + s.cache.Delete(ns + oldDigestChain) + } +} + +// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。 +func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" { + return "", 0, "", false + } + ns := buildNS(groupID, prefixHash) + chain := digestChain + for { + if val, ok := s.cache.Get(ns + chain); ok { + if e, ok := val.(*sessionEntry); ok { + return e.uuid, e.accountID, chain, true + } + } + i := strings.LastIndex(chain, "-") + if i < 0 { + return "", 0, "", false + } + chain = chain[:i] + } +} + +// buildNS 构建 namespace 前缀 +func buildNS(groupID int64, prefixHash string) string { + return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|" +} diff --git a/backend/internal/service/digest_session_store_test.go b/backend/internal/service/digest_session_store_test.go new file mode 100644 index 00000000..e505bf30 --- /dev/null +++ b/backend/internal/service/digest_session_store_test.go @@ -0,0 +1,312 @@ +//go:build unit + +package service + +import ( + "fmt" + "sync" + "testing" + "time" + + gocache "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDigestSessionStore_SaveAndFind(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "") + + uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) +} + +func TestDigestSessionStore_PrefixMatch(t *testing.T) { + store := NewDigestSessionStore() + + // 保存短链 + store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "") + + // 用长链查找,应前缀匹配到短链 + uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d") + require.True(t, found) + assert.Equal(t, "uuid-short", uuid) + assert.Equal(t, int64(10), accountID) + assert.Equal(t, "u:a-m:b", matchedChain) +} + +func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a", "uuid-1", 1, "") + store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "") + store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "") + + // 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的) + uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e") + require.True(t, found) + assert.Equal(t, "uuid-3", uuid) + assert.Equal(t, int64(3), accountID) + + // 查找中等长度,应匹配到 "u:a-m:b" + uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x") + require.True(t, found) + assert.Equal(t, "uuid-2", uuid) + assert.Equal(t, int64(2), accountID) +} + +func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) { + store := NewDigestSessionStore() + + // 第一轮:保存 "u:a-m:b" + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 第二轮:同一 uuid 保存更长的链,传入旧 chain + store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b") + + // 旧链 "u:a-m:b" 应已被删除 + _, _, _, found := store.Find(1, "prefix", "u:a-m:b") + assert.False(t, found, "old chain should be deleted") + + // 新链应能找到 + uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) +} + +func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) { + store := NewDigestSessionStore() + + // 相同系统提示词,不同用户提示词 + store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "") + store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "") + + uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) + + uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2") + require.True(t, found) + assert.Equal(t, "uuid-2", uuid) + assert.Equal(t, int64(200), accountID) +} + +func TestDigestSessionStore_NoMatch(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 完全不同的 chain + _, _, _, found := store.Find(1, "prefix", "u:x-m:y") + assert.False(t, found) +} + +func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "") + + // 不同 prefixHash 应隔离 + _, _, _, found := store.Find(1, "prefix2", "u:a-m:b") + assert.False(t, found) +} + +func TestDigestSessionStore_DifferentGroupID(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 不同 groupID 应隔离 + _, _, _, found := store.Find(2, "prefix", "u:a-m:b") + assert.False(t, found) +} + +func TestDigestSessionStore_EmptyDigestChain(t *testing.T) { + store := NewDigestSessionStore() + + // 空链不应保存 + store.Save(1, "prefix", "", "uuid-1", 100, "") + _, _, _, found := store.Find(1, "prefix", "") + assert.False(t, found) +} + +func TestDigestSessionStore_TTLExpiration(t *testing.T) { + store := &DigestSessionStore{ + cache: gocache.New(100*time.Millisecond, 50*time.Millisecond), + } + + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 立即应该能找到 + _, _, _, found := store.Find(1, "prefix", "u:a-m:b") + require.True(t, found) + + // 等待过期 + 清理周期 + time.Sleep(300 * time.Millisecond) + + // 过期后应找不到 + _, _, _, found = store.Find(1, "prefix", "u:a-m:b") + assert.False(t, found) +} + +func TestDigestSessionStore_ConcurrentSafety(t *testing.T) { + store := NewDigestSessionStore() + + var wg sync.WaitGroup + const goroutines = 50 + const operations = 100 + + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + go func(id int) { + defer wg.Done() + prefix := fmt.Sprintf("prefix-%d", id%5) + for i := 0; i < operations; i++ { + chain := fmt.Sprintf("u:%d-m:%d", id, i) + uuid := fmt.Sprintf("uuid-%d-%d", id, i) + store.Save(1, prefix, chain, uuid, int64(id), "") + store.Find(1, prefix, chain) + } + }(g) + } + wg.Wait() +} + +func TestDigestSessionStore_MultipleSessions(t *testing.T) { + store := NewDigestSessionStore() + + sessions := []struct { + chain string + uuid string + accountID int64 + }{ + {"u:session1", "uuid-1", 1}, + {"u:session2-m:reply2", "uuid-2", 2}, + {"u:session3-m:reply3-u:msg3", "uuid-3", 3}, + } + + for _, sess := range sessions { + store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "") + } + + // 验证每个会话都能正确查找 + for _, sess := range sessions { + uuid, accountID, _, found := store.Find(1, "prefix", sess.chain) + require.True(t, found, "should find session: %s", sess.chain) + assert.Equal(t, sess.uuid, uuid) + assert.Equal(t, sess.accountID, accountID) + } + + // 验证继续对话的场景 + uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg") + require.True(t, found) + assert.Equal(t, "uuid-2", uuid) + assert.Equal(t, int64(2), accountID) +} + +func TestDigestSessionStore_Performance1000Sessions(t *testing.T) { + store := NewDigestSessionStore() + + // 插入 1000 个会话 + for i := 0; i < 1000; i++ { + chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i) + store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "") + } + + // 查找性能测试 + start := time.Now() + const lookups = 10000 + for i := 0; i < lookups; i++ { + idx := i % 1000 + chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx) + _, _, _, found := store.Find(1, "prefix", chain) + assert.True(t, found) + } + elapsed := time.Since(start) + t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups) +} + +func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "") + + // 精确匹配 + _, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c") + require.True(t, found) + assert.Equal(t, "u:a-m:b-u:c", matchedChain) + + // 前缀匹配(截断后命中) + _, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e") + require.True(t, found) + assert.Equal(t, "u:a-m:b-u:c", matchedChain) +} + +func TestDigestSessionStore_CacheItemCountStable(t *testing.T) { + store := NewDigestSessionStore() + + // 模拟 100 个独立会话,每个进行 10 轮对话 + // 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key + for conv := 0; conv < 100; conv++ { + var prevMatchedChain string + for round := 0; round < 10; round++ { + chain := fmt.Sprintf("s:sys-u:user%d", conv) + for r := 0; r < round; r++ { + chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1) + } + uuid := fmt.Sprintf("uuid-conv%d", conv) + + _, _, matched, _ := store.Find(1, "prefix", chain) + store.Save(1, "prefix", chain, uuid, int64(conv), matched) + prevMatchedChain = matched + _ = prevMatchedChain + } + } + + // 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key + // 允许少量并发残留,但绝不能接近 100×10=1000 + itemCount := store.cache.ItemCount() + assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount) + t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount) +} + +func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) { + // 使用极短 TTL 验证大量写入后 cache 能被清理 + store := &DigestSessionStore{ + cache: gocache.New(100*time.Millisecond, 50*time.Millisecond), + } + + // 插入 500 个不同的 key(无 oldDigestChain,模拟最坏场景:全是新会话首轮) + for i := 0; i < 500; i++ { + chain := fmt.Sprintf("u:user%d", i) + store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "") + } + + assert.Equal(t, 500, store.cache.ItemCount()) + + // 等待 TTL + 清理周期 + time.Sleep(300 * time.Millisecond) + + assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up") +} + +func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) { + store := NewDigestSessionStore() + + // 保存 chain + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 用户重发相同消息:oldDigestChain == digestChain,不应删掉刚设置的 key + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b") + + // 仍然能找到 + uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 0295c23b..df213002 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -24,6 +24,7 @@ const ( PlatformOpenAI = domain.PlatformOpenAI PlatformGemini = domain.PlatformGemini PlatformAntigravity = domain.PlatformAntigravity + PlatformSora = domain.PlatformSora ) // Account type constants @@ -103,6 +104,7 @@ const ( SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" // OEM设置 + SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制) SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 @@ -111,12 +113,14 @@ const ( SettingKeyDocURL = "doc_url" // 文档链接 SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src) SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮 - SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口 - SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src) + SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 + SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src) + SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组) // 默认配置 - SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 - SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 + SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 + SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 + SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) // 管理员 API Key SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) @@ -160,12 +164,43 @@ const ( // SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation). SettingKeyOpsAdvancedSettings = "ops_advanced_settings" + // SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings. + SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config" + // ========================= // Stream Timeout Handling // ========================= // SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling. SettingKeyStreamTimeoutSettings = "stream_timeout_settings" + + // ========================= + // Sora S3 存储配置 + // ========================= + + SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储 + SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址 + SettingKeySoraS3Region = "sora_s3_region" // S3 区域 + SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称 + SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID + SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储) + SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀 + SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等) + SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选) + SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON) + + // ========================= + // Sora 用户存储配额 + // ========================= + + SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节) + + // ========================= + // Claude Code Version Check + // ========================= + + // SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查) + SettingKeyMinClaudeCodeVersion = "min_claude_code_version" ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). diff --git a/backend/internal/service/email_queue_service.go b/backend/internal/service/email_queue_service.go index 6c975c69..d8f0a518 100644 --- a/backend/internal/service/email_queue_service.go +++ b/backend/internal/service/email_queue_service.go @@ -3,9 +3,10 @@ package service import ( "context" "fmt" - "log" "sync" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // Task type constants @@ -56,7 +57,7 @@ func (s *EmailQueueService) start() { s.wg.Add(1) go s.worker(i) } - log.Printf("[EmailQueue] Started %d workers", s.workers) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Started %d workers", s.workers) } // worker 工作协程 @@ -68,7 +69,7 @@ func (s *EmailQueueService) worker(id int) { case task := <-s.taskChan: s.processTask(id, task) case <-s.stopChan: - log.Printf("[EmailQueue] Worker %d stopping", id) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d stopping", id) return } } @@ -82,18 +83,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) { switch task.TaskType { case TaskTypeVerifyCode: if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil { - log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err) } else { - log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email) } case TaskTypePasswordReset: if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil { - log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err) } else { - log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email) } default: - log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType) } } @@ -107,7 +108,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error { select { case s.taskChan <- task: - log.Printf("[EmailQueue] Enqueued verify code task for %s", email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued verify code task for %s", email) return nil default: return fmt.Errorf("email queue is full") @@ -125,7 +126,7 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin select { case s.taskChan <- task: - log.Printf("[EmailQueue] Enqueued password reset task for %s", email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued password reset task for %s", email) return nil default: return fmt.Errorf("email queue is full") @@ -136,5 +137,5 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin func (s *EmailQueueService) Stop() { close(s.stopChan) s.wg.Wait() - log.Println("[EmailQueue] All workers stopped") + logger.LegacyPrintf("service.email_queue", "%s", "[EmailQueue] All workers stopped") } diff --git a/backend/internal/service/error_passthrough_runtime.go b/backend/internal/service/error_passthrough_runtime.go index 65085d6f..011c3ce4 100644 --- a/backend/internal/service/error_passthrough_runtime.go +++ b/backend/internal/service/error_passthrough_runtime.go @@ -61,6 +61,11 @@ func applyErrorPassthroughRule( errMsg = *rule.CustomMessage } + // 命中 skip_monitoring 时在 context 中标记,供 ops_error_logger 跳过记录。 + if rule.SkipMonitoring { + c.Set(OpsSkipPassthroughKey, true) + } + // 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。 errType = "upstream_error" return status, errType, errMsg, true diff --git a/backend/internal/service/error_passthrough_runtime_test.go b/backend/internal/service/error_passthrough_runtime_test.go index 393e6e59..7032d15b 100644 --- a/backend/internal/service/error_passthrough_runtime_test.go +++ b/backend/internal/service/error_passthrough_runtime_test.go @@ -76,7 +76,7 @@ func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) { } account := &Account{ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + _, err := svc.handleErrorResponse(context.Background(), resp, c, account, nil) require.Error(t, err) assert.Equal(t, http.StatusBadGateway, rec.Code) @@ -157,7 +157,7 @@ func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) { } account := &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + _, err := svc.handleErrorResponse(context.Background(), resp, c, account, nil) require.Error(t, err) assert.Equal(t, http.StatusTeapot, rec.Code) @@ -194,6 +194,63 @@ func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) { assert.Equal(t, "Gemini上游失败", errField["message"]) } +func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限") + rule.SkipMonitoring = true + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule}) + BindErrorPassthroughService(c, ruleSvc) + + _, _, _, matched := applyErrorPassthroughRule( + c, + PlatformAnthropic, + http.StatusBadRequest, + []byte(`{"error":{"message":"prompt is too long"}}`), + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ) + + assert.True(t, matched) + v, exists := c.Get(OpsSkipPassthroughKey) + assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true") + boolVal, ok := v.(bool) + assert.True(t, ok, "value should be bool") + assert.True(t, boolVal) +} + +func TestApplyErrorPassthroughRule_NoSkipMonitoringDoesNotSetContextKey(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限") + rule.SkipMonitoring = false + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule}) + BindErrorPassthroughService(c, ruleSvc) + + _, _, _, matched := applyErrorPassthroughRule( + c, + PlatformAnthropic, + http.StatusBadRequest, + []byte(`{"error":{"message":"prompt is too long"}}`), + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ) + + assert.True(t, matched) + _, exists := c.Get(OpsSkipPassthroughKey) + assert.False(t, exists, "OpsSkipPassthroughKey should NOT be set when skip_monitoring=false") +} + func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule { return &model.ErrorPassthroughRule{ ID: 1, diff --git a/backend/internal/service/error_passthrough_service.go b/backend/internal/service/error_passthrough_service.go index c3e0f630..26fdf9a7 100644 --- a/backend/internal/service/error_passthrough_service.go +++ b/backend/internal/service/error_passthrough_service.go @@ -2,13 +2,13 @@ package service import ( "context" - "log" "sort" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // ErrorPassthroughRepository 定义错误透传规则的数据访问接口 @@ -45,10 +45,20 @@ type ErrorPassthroughService struct { cache ErrorPassthroughCache // 本地内存缓存,用于快速匹配 - localCache []*model.ErrorPassthroughRule + localCache []*cachedPassthroughRule localCacheMu sync.RWMutex } +// cachedPassthroughRule 预计算的规则缓存,避免运行时重复 ToLower +type cachedPassthroughRule struct { + *model.ErrorPassthroughRule + lowerKeywords []string // 预计算的小写关键词 + lowerPlatforms []string // 预计算的小写平台 + errorCodeSet map[int]struct{} // 预计算的 error code set +} + +const maxBodyMatchLen = 8 << 10 // 8KB,错误信息不会在 8KB 之后才出现 + // NewErrorPassthroughService 创建错误透传规则服务 func NewErrorPassthroughService( repo ErrorPassthroughRepository, @@ -62,9 +72,9 @@ func NewErrorPassthroughService( // 启动时加载规则到本地缓存 ctx := context.Background() if err := svc.reloadRulesFromDB(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err) if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil { - log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr) } } @@ -72,7 +82,7 @@ func NewErrorPassthroughService( if cache != nil { cache.SubscribeUpdates(ctx, func() { if err := svc.refreshLocalCache(context.Background()); err != nil { - log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache on notification: %v", err) } }) } @@ -150,17 +160,19 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod return nil } - bodyStr := strings.ToLower(string(body)) + lowerPlatform := strings.ToLower(platform) + var bodyLower string // 延迟初始化,只在需要关键词匹配时计算 + var bodyLowerDone bool for _, rule := range rules { if !rule.Enabled { continue } - if !s.platformMatches(rule, platform) { + if !s.platformMatchesCached(rule, lowerPlatform) { continue } - if s.ruleMatches(rule, statusCode, bodyStr) { - return rule + if s.ruleMatchesOptimized(rule, statusCode, body, &bodyLower, &bodyLowerDone) { + return rule.ErrorPassthroughRule } } @@ -168,7 +180,7 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod } // getCachedRules 获取缓存的规则列表(按优先级排序) -func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule { +func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule { s.localCacheMu.RLock() rules := s.localCache s.localCacheMu.RUnlock() @@ -180,7 +192,7 @@ func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule // 如果本地缓存为空,尝试刷新 ctx := context.Background() if err := s.refreshLocalCache(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache: %v", err) return nil } @@ -213,7 +225,7 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error { // 更新 Redis 缓存 if s.cache != nil { if err := s.cache.Set(ctx, rules); err != nil { - log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to set cache: %v", err) } } @@ -223,17 +235,39 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error { return nil } -// setLocalCache 设置本地缓存 +// setLocalCache 设置本地缓存,预计算小写值和 set 以避免运行时重复计算 func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) { + cached := make([]*cachedPassthroughRule, len(rules)) + for i, r := range rules { + cr := &cachedPassthroughRule{ErrorPassthroughRule: r} + if len(r.Keywords) > 0 { + cr.lowerKeywords = make([]string, len(r.Keywords)) + for j, kw := range r.Keywords { + cr.lowerKeywords[j] = strings.ToLower(kw) + } + } + if len(r.Platforms) > 0 { + cr.lowerPlatforms = make([]string, len(r.Platforms)) + for j, p := range r.Platforms { + cr.lowerPlatforms[j] = strings.ToLower(p) + } + } + if len(r.ErrorCodes) > 0 { + cr.errorCodeSet = make(map[int]struct{}, len(r.ErrorCodes)) + for _, code := range r.ErrorCodes { + cr.errorCodeSet[code] = struct{}{} + } + } + cached[i] = cr + } + // 按优先级排序 - sorted := make([]*model.ErrorPassthroughRule, len(rules)) - copy(sorted, rules) - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Priority < sorted[j].Priority + sort.Slice(cached, func(i, j int) bool { + return cached[i].Priority < cached[j].Priority }) s.localCacheMu.Lock() - s.localCache = sorted + s.localCache = cached s.localCacheMu.Unlock() } @@ -254,13 +288,13 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { // 先失效缓存,避免后续刷新读到陈旧规则。 if s.cache != nil { if err := s.cache.Invalidate(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to invalidate cache: %v", err) } } // 刷新本地缓存 if err := s.reloadRulesFromDB(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh local cache: %v", err) // 刷新失败时清空本地缓存,避免继续使用陈旧规则。 s.clearLocalCache() } @@ -268,67 +302,84 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { // 通知其他实例 if s.cache != nil { if err := s.cache.NotifyUpdate(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to notify cache update: %v", err) } } } -// platformMatches 检查平台是否匹配 -func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool { - // 如果没有配置平台限制,则匹配所有平台 - if len(rule.Platforms) == 0 { +// ensureBodyLower 延迟初始化 body 的小写版本,只做一次转换,限制 8KB +func ensureBodyLower(body []byte, bodyLower *string, done *bool) string { + if *done { + return *bodyLower + } + b := body + if len(b) > maxBodyMatchLen { + b = b[:maxBodyMatchLen] + } + *bodyLower = strings.ToLower(string(b)) + *done = true + return *bodyLower +} + +// platformMatchesCached 使用预计算的小写平台检查是否匹配 +func (s *ErrorPassthroughService) platformMatchesCached(rule *cachedPassthroughRule, lowerPlatform string) bool { + if len(rule.lowerPlatforms) == 0 { return true } - - platform = strings.ToLower(platform) - for _, p := range rule.Platforms { - if strings.ToLower(p) == platform { + for _, p := range rule.lowerPlatforms { + if p == lowerPlatform { return true } } - return false } -// ruleMatches 检查规则是否匹配 -func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool { - hasErrorCodes := len(rule.ErrorCodes) > 0 - hasKeywords := len(rule.Keywords) > 0 +// ruleMatchesOptimized 优化的规则匹配,支持短路和延迟 body 转换 +func (s *ErrorPassthroughService) ruleMatchesOptimized(rule *cachedPassthroughRule, statusCode int, body []byte, bodyLower *string, bodyLowerDone *bool) bool { + hasErrorCodes := len(rule.errorCodeSet) > 0 + hasKeywords := len(rule.lowerKeywords) > 0 - // 如果没有配置任何条件,不匹配 if !hasErrorCodes && !hasKeywords { return false } - codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode) - keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords) + codeMatch := !hasErrorCodes || s.containsIntSet(rule.errorCodeSet, statusCode) if rule.MatchMode == model.MatchModeAll { - // "all" 模式:所有配置的条件都必须满足 - return codeMatch && keywordMatch + // "all" 模式:所有配置的条件都必须满足,短路 + if hasErrorCodes && !codeMatch { + return false + } + if hasKeywords { + return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords) + } + return codeMatch } - // "any" 模式:任一条件满足即可 + // "any" 模式:任一条件满足即可,短路 if hasErrorCodes && hasKeywords { - return codeMatch || keywordMatch + if codeMatch { + return true + } + return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords) } - return codeMatch && keywordMatch + // 只配置了一种条件 + if hasKeywords { + return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords) + } + return codeMatch } -// containsInt 检查切片是否包含指定整数 -func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool { - for _, v := range slice { - if v == val { - return true - } - } - return false -} - -// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写) -func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool { - for _, kw := range keywords { - if strings.Contains(bodyLower, strings.ToLower(kw)) { +// containsIntSet 使用 map 查找替代线性扫描 +func (s *ErrorPassthroughService) containsIntSet(set map[int]struct{}, val int) bool { + _, ok := set[val] + return ok +} + +// containsAnyKeywordCached 使用预计算的小写关键词检查匹配 +func (s *ErrorPassthroughService) containsAnyKeywordCached(bodyLower string, lowerKeywords []string) bool { + for _, kw := range lowerKeywords { + if strings.Contains(bodyLower, kw) { return true } } diff --git a/backend/internal/service/error_passthrough_service_test.go b/backend/internal/service/error_passthrough_service_test.go index 74c98d86..96ddd637 100644 --- a/backend/internal/service/error_passthrough_service_test.go +++ b/backend/internal/service/error_passthrough_service_test.go @@ -145,32 +145,58 @@ func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughServic return svc } +// newCachedRuleForTest 从 model.ErrorPassthroughRule 创建 cachedPassthroughRule(测试用) +func newCachedRuleForTest(rule *model.ErrorPassthroughRule) *cachedPassthroughRule { + cr := &cachedPassthroughRule{ErrorPassthroughRule: rule} + if len(rule.Keywords) > 0 { + cr.lowerKeywords = make([]string, len(rule.Keywords)) + for j, kw := range rule.Keywords { + cr.lowerKeywords[j] = strings.ToLower(kw) + } + } + if len(rule.Platforms) > 0 { + cr.lowerPlatforms = make([]string, len(rule.Platforms)) + for j, p := range rule.Platforms { + cr.lowerPlatforms[j] = strings.ToLower(p) + } + } + if len(rule.ErrorCodes) > 0 { + cr.errorCodeSet = make(map[int]struct{}, len(rule.ErrorCodes)) + for _, code := range rule.ErrorCodes { + cr.errorCodeSet[code] = struct{}{} + } + } + return cr +} + // ============================================================================= -// 测试 ruleMatches 核心匹配逻辑 +// 测试 ruleMatchesOptimized 核心匹配逻辑 // ============================================================================= func TestRuleMatches_NoConditions(t *testing.T) { // 没有配置任何条件时,不应该匹配 svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{}, Keywords: []string{}, MatchMode: model.MatchModeAny, - } + }) - assert.False(t, svc.ruleMatches(rule, 422, "some error message"), + var bodyLower string + var bodyLowerDone bool + assert.False(t, svc.ruleMatchesOptimized(rule, 422, []byte("some error message"), &bodyLower, &bodyLowerDone), "没有配置条件时不应该匹配") } func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{}, MatchMode: model.MatchModeAny, - } + }) tests := []struct { name string @@ -186,7 +212,9 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := svc.ruleMatches(rule, tt.statusCode, tt.body) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result) }) } @@ -194,12 +222,12 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{}, Keywords: []string{"context limit", "model not supported"}, MatchMode: model.MatchModeAny, - } + }) tests := []struct { name string @@ -210,16 +238,14 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { {"关键词匹配 context limit", 500, "error: context limit reached", true}, {"关键词匹配 model not supported", 400, "the model not supported here", true}, {"关键词不匹配", 422, "some other error", false}, - // 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的 - // 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches - {"关键词大小写 - 输入已小写", 500, "context limit exceeded", true}, + {"关键词大小写 - 自动转换", 500, "Context Limit exceeded", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 模拟 MatchRule 的行为:先转换为小写 - bodyLower := strings.ToLower(tt.body) - result := svc.ruleMatches(rule, tt.statusCode, bodyLower) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result) }) } @@ -228,12 +254,12 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { // any 模式:错误码 OR 关键词 svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{"context limit"}, MatchMode: model.MatchModeAny, - } + }) tests := []struct { name string @@ -274,7 +300,9 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := svc.ruleMatches(rule, tt.statusCode, tt.body) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result, tt.reason) }) } @@ -283,12 +311,12 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { func TestRuleMatches_BothConditions_AllMode(t *testing.T) { // all 模式:错误码 AND 关键词 svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{"context limit"}, MatchMode: model.MatchModeAll, - } + }) tests := []struct { name string @@ -329,14 +357,16 @@ func TestRuleMatches_BothConditions_AllMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := svc.ruleMatches(rule, tt.statusCode, tt.body) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result, tt.reason) }) } } // ============================================================================= -// 测试 platformMatches 平台匹配逻辑 +// 测试 platformMatchesCached 平台匹配逻辑 // ============================================================================= func TestPlatformMatches(t *testing.T) { @@ -394,10 +424,10 @@ func TestPlatformMatches(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Platforms: tt.rulePlatforms, - } - result := svc.platformMatches(rule, tt.requestPlatform) + }) + result := svc.platformMatchesCached(rule, strings.ToLower(tt.requestPlatform)) assert.Equal(t, tt.expected, result) }) } diff --git a/backend/internal/service/error_policy_integration_test.go b/backend/internal/service/error_policy_integration_test.go new file mode 100644 index 00000000..a8b42a2c --- /dev/null +++ b/backend/internal/service/error_policy_integration_test.go @@ -0,0 +1,472 @@ +//go:build unit + +package service + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mocks (scoped to this file by naming convention) +// --------------------------------------------------------------------------- + +// epFixedUpstream returns a fixed response for every request. +type epFixedUpstream struct { + statusCode int + body string + calls int +} + +func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.calls++ + return &http.Response{ + StatusCode: u.statusCode, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(u.body)), + }, nil +} + +func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +// epAccountRepo records SetTempUnschedulable / SetError calls. +type epAccountRepo struct { + mockAccountRepoForGemini + tempCalls int + setErrCalls int +} + +func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.tempCalls++ + return nil +} + +func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error { + r.setErrCalls++ + return nil +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func saveAndSetBaseURLs(t *testing.T) { + t.Helper() + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) + oldAvail := antigravity.DefaultURLAvailability + antigravity.BaseURLs = []string{"https://ep-test.example"} + antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute) + t.Cleanup(func() { + antigravity.BaseURLs = oldBaseURLs + antigravity.DefaultURLAvailability = oldAvail + }) +} + +func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams { + return antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[ep-test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + requestedModel: "claude-sonnet-4-5", + handleError: handleError, + } +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_CustomErrorCodes +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) { + tests := []struct { + name string + upstreamStatus int + upstreamBody string + customCodes []any + expectHandleError int + expectUpstream int + expectStatusCode int + }{ + { + name: "429_in_custom_codes_matched", + upstreamStatus: 429, + upstreamBody: `{"error":"rate limited"}`, + customCodes: []any{float64(429)}, + expectHandleError: 1, + expectUpstream: 1, + expectStatusCode: 429, + }, + { + name: "429_not_in_custom_codes_skipped", + upstreamStatus: 429, + upstreamBody: `{"error":"rate limited"}`, + customCodes: []any{float64(500)}, + expectHandleError: 0, + expectUpstream: 1, + expectStatusCode: 500, + }, + { + name: "500_in_custom_codes_matched", + upstreamStatus: 500, + upstreamBody: `{"error":"internal"}`, + customCodes: []any{float64(500)}, + expectHandleError: 1, + expectUpstream: 1, + expectStatusCode: 500, + }, + { + name: "500_not_in_custom_codes_skipped", + upstreamStatus: 500, + upstreamBody: `{"error":"internal"}`, + customCodes: []any{float64(429)}, + expectHandleError: 0, + expectUpstream: 1, + expectStatusCode: 500, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 100, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": tt.customCodes, + }, + } + + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.resp) + defer func() { _ = result.resp.Body.Close() }() + + require.Equal(t, tt.expectStatusCode, result.resp.StatusCode) + require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count") + require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count") + }) + } +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_TempUnschedulable +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) { + tempRulesAccount := func(rules []any) *Account { + return &Account{ + ID: 200, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": rules, + }, + } + } + + overloadedRule := map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + } + + rateLimitRule := map[string]any{ + "error_code": float64(429), + "keywords": []any{"rate limited keyword"}, + "duration_minutes": float64(5), + } + + t.Run("503_overloaded_matches_rule", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{overloadedRule}) + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + t.Error("handleError should not be called for temp unschedulable") + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.Nil(t, result) + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, 1, upstream.calls, "should not retry") + }) + + t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{rateLimitRule}) + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + t.Error("handleError should not be called for temp unschedulable") + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.Nil(t, result) + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, 1, upstream.calls, "should not retry") + }) + + t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 503, body: `random`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{overloadedRule}) + + // Use a short-lived context: the backoff sleep (~1s) will be + // interrupted, proving the code entered the default retry path + // instead of breaking early via error policy. + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + return nil + }) + p.ctx = ctx + + result, err := svc.antigravityRetryLoop(p) + + // Context cancellation during backoff proves default retry was entered + require.Nil(t, result) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once") + }) +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_NilRateLimitService +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`} + // rateLimitService is nil — must not panic + svc := &AntigravityGatewayService{rateLimitService: nil} + + account := &Account{ + ID: 300, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + return nil + }) + p.ctx = ctx + + // Should not panic; enters the default retry path (eventually times out) + result, err := svc.antigravityRetryLoop(p) + + require.Nil(t, result) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.GreaterOrEqual(t, upstream.calls, 1) +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + // Plain OAuth account with no error policy configured + account := &Account{ + ID: 400, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.resp) + defer func() { _ = result.resp.Body.Close() }() + + require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode) + require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries") + require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted") +} + +// --------------------------------------------------------------------------- +// epTrackingRepo — records SetRateLimited / SetError calls for verification. +// --------------------------------------------------------------------------- + +type epTrackingRepo struct { + mockAccountRepoForGemini + rateLimitedCalls int + rateLimitedID int64 + setErrCalls int + setErrID int64 + tempCalls int +} + +func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error { + r.rateLimitedCalls++ + r.rateLimitedID = id + return nil +} + +func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error { + r.setErrCalls++ + r.setErrID = id + return nil +} + +func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.tempCalls++ + return nil +} + +// --------------------------------------------------------------------------- +// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit +// +// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码), +// 当上游返回 429/500/503/401 时: +// - 返回给客户端的状态码必须是 500(而不是透传原始状态码) +// - 不调用 SetRateLimited(不进入限流状态) +// - 不调用 SetError(不停止调度) +// - 不调用 handleError +// --------------------------------------------------------------------------- + +func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) { + errorCodes := []int{429, 500, 503, 401, 403} + + for _, upstreamStatus := range errorCodes { + t.Run(http.StatusText(upstreamStatus), func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{ + statusCode: upstreamStatus, + body: `{"error":"some upstream error"}`, + } + repo := &epTrackingRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := &Account{ + ID: 500, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(599)}, + }, + } + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + // 不应返回 error(Skipped 不触发账号切换) + require.NoError(t, err, "should not return error") + require.NotNil(t, result, "result should not be nil") + require.NotNil(t, result.resp, "response should not be nil") + defer func() { _ = result.resp.Body.Close() }() + + // 状态码必须是 500(不透传原始状态码) + require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode, + "skipped error should return 500, not %d", upstreamStatus) + + // 不调用 handleError + require.Equal(t, 0, handleErrorCount, + "handleError should NOT be called for skipped errors") + + // 不标记限流 + require.Equal(t, 0, repo.rateLimitedCalls, + "SetRateLimited should NOT be called for skipped errors") + + // 不停止调度 + require.Equal(t, 0, repo.setErrCalls, + "SetError should NOT be called for skipped errors") + + // 只调用一次上游(不重试) + require.Equal(t, 1, upstream.calls, + "should call upstream exactly once (no retry)") + }) + } +} diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go new file mode 100644 index 00000000..9d7d025e --- /dev/null +++ b/backend/internal/service/error_policy_test.go @@ -0,0 +1,295 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function +// --------------------------------------------------------------------------- + +func TestCheckErrorPolicy(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expected ErrorPolicyResult + }{ + { + name: "no_policy_oauth_returns_none", + account: &Account{ + ID: 1, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + // no custom error codes, no temp rules + }, + statusCode: 500, + body: []byte(`"error"`), + expected: ErrorPolicyNone, + }, + { + name: "custom_error_codes_hit_returns_matched", + account: &Account{ + ID: 2, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 500, + body: []byte(`"error"`), + expected: ErrorPolicyMatched, + }, + { + name: "custom_error_codes_miss_returns_skipped", + account: &Account{ + ID: 3, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 503, + body: []byte(`"error"`), + expected: ErrorPolicySkipped, + }, + { + name: "temp_unschedulable_hit_returns_temp_unscheduled", + account: &Account{ + ID: 4, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded service`), + expected: ErrorPolicyTempUnscheduled, + }, + { + name: "temp_unschedulable_body_miss_returns_none", + account: &Account{ + ID: 5, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`random msg`), + expected: ErrorPolicyNone, + }, + { + name: "custom_error_codes_override_temp_unschedulable", + account: &Account{ + ID: 6, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(503)}, + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expected: ErrorPolicyMatched, // custom codes take precedence + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body) + require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult") + }) + } +} + +// --------------------------------------------------------------------------- +// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method +// --------------------------------------------------------------------------- + +func TestApplyErrorPolicy(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expectedHandled bool + expectedStatus int // expected outStatus + expectedSwitchErr bool // expect *AntigravityAccountSwitchError + handleErrorCalls int + }{ + { + name: "none_not_handled", + account: &Account{ + ID: 10, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + }, + statusCode: 500, + body: []byte(`"error"`), + expectedHandled: false, + expectedStatus: 500, // passthrough + handleErrorCalls: 0, + }, + { + name: "skipped_handled_no_handleError", + account: &Account{ + ID: 11, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, // not in custom codes + body: []byte(`"error"`), + expectedHandled: true, + expectedStatus: http.StatusInternalServerError, // skipped → 500 + handleErrorCalls: 0, + }, + { + name: "matched_handled_calls_handleError", + account: &Account{ + ID: 12, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(500)}, + }, + }, + statusCode: 500, + body: []byte(`"error"`), + expectedHandled: true, + expectedStatus: 500, // matched → original status + handleErrorCalls: 1, + }, + { + name: "temp_unscheduled_returns_switch_error", + account: &Account{ + ID: 13, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expectedHandled: true, + expectedStatus: 503, // temp_unscheduled → original status + expectedSwitchErr: true, + handleErrorCalls: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{ + rateLimitService: rlSvc, + } + + var handleErrorCount int + p := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: tt.account, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }, + isStickySession: true, + } + + handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) + + require.Equal(t, tt.expectedHandled, handled, "handled mismatch") + require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch") + require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch") + + if tt.expectedSwitchErr { + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, retErr, &switchErr) + require.Equal(t, tt.account.ID, switchErr.OriginalAccountID) + } else { + require.NoError(t, retErr) + } + }) + } +} + +// --------------------------------------------------------------------------- +// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests +// --------------------------------------------------------------------------- + +type errorPolicyRepoStub struct { + mockAccountRepoForGemini + tempCalls int + setErrCalls int + lastErrorMsg string +} + +func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + r.tempCalls++ + return nil +} + +func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { + r.setErrCalls++ + r.lastErrorMsg = errorMsg + return nil +} diff --git a/backend/internal/service/gateway_account_selection_test.go b/backend/internal/service/gateway_account_selection_test.go new file mode 100644 index 00000000..0a82fade --- /dev/null +++ b/backend/internal/service/gateway_account_selection_test.go @@ -0,0 +1,206 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// --- helpers --- + +func testTimePtr(t time.Time) *time.Time { return &t } + +func makeAccWithLoad(id int64, priority int, loadRate int, lastUsed *time.Time, accType string) accountWithLoad { + return accountWithLoad{ + account: &Account{ + ID: id, + Priority: priority, + LastUsedAt: lastUsed, + Type: accType, + Schedulable: true, + Status: StatusActive, + }, + loadInfo: &AccountLoadInfo{ + AccountID: id, + CurrentConcurrency: 0, + LoadRate: loadRate, + }, + } +} + +// --- sortAccountsByPriorityAndLastUsed --- + +func TestSortAccountsByPriorityAndLastUsed_ByPriority(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 5, LastUsedAt: testTimePtr(now)}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 3, Priority: 3, LastUsedAt: testTimePtr(now)}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(2), accounts[0].ID, "优先级最低的排第一") + require.Equal(t, int64(3), accounts[1].ID) + require.Equal(t, int64(1), accounts[2].ID) +} + +func TestSortAccountsByPriorityAndLastUsed_SamePriorityByLastUsed(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))}, + {ID: 3, Priority: 1, LastUsedAt: nil}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(3), accounts[0].ID, "nil LastUsedAt 排最前") + require.Equal(t, int64(2), accounts[1].ID, "更早使用的排前面") + require.Equal(t, int64(1), accounts[2].ID) +} + +func TestSortAccountsByPriorityAndLastUsed_PreferOAuth(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeOAuth}, + } + sortAccountsByPriorityAndLastUsed(accounts, true) + require.Equal(t, int64(2), accounts[0].ID, "preferOAuth 时 OAuth 账号排前面") +} + +func TestSortAccountsByPriorityAndLastUsed_StableSort(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 3, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + } + + // sortAccountsByPriorityAndLastUsed 内部会在同组(Priority+LastUsedAt)内做随机打散, + // 因此这里不再断言“稳定排序”。我们只验证: + // 1) 元素集合不变;2) 多次运行能产生不同的顺序。 + seenFirst := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + sortAccountsByPriorityAndLastUsed(cpy, false) + seenFirst[cpy[0].ID] = true + + ids := map[int64]bool{} + for _, a := range cpy { + ids[a.ID] = true + } + require.True(t, ids[1] && ids[2] && ids[3]) + } + require.GreaterOrEqual(t, len(seenFirst), 2, "同组账号应能被随机打散") +} + +func TestSortAccountsByPriorityAndLastUsed_MixedPriorityAndTime(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 2, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 3, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))}, + {ID: 4, Priority: 2, LastUsedAt: testTimePtr(now.Add(-2 * time.Hour))}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + // 优先级1排前:nil < earlier + require.Equal(t, int64(3), accounts[0].ID, "优先级1 + 更早") + require.Equal(t, int64(2), accounts[1].ID, "优先级1 + 现在") + // 优先级2排后:nil < time + require.Equal(t, int64(1), accounts[2].ID, "优先级2 + nil") + require.Equal(t, int64(4), accounts[3].ID, "优先级2 + 有时间") +} + +// --- filterByMinPriority --- + +func TestFilterByMinPriority_Empty(t *testing.T) { + result := filterByMinPriority(nil) + require.Nil(t, result) +} + +func TestFilterByMinPriority_SelectsMinPriority(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 5, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 20, nil, AccountTypeAPIKey), + makeAccWithLoad(4, 2, 10, nil, AccountTypeAPIKey), + } + result := filterByMinPriority(accounts) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(3), result[1].account.ID) +} + +// --- filterByMinLoadRate --- + +func TestFilterByMinLoadRate_Empty(t *testing.T) { + result := filterByMinLoadRate(nil) + require.Nil(t, result) +} + +func TestFilterByMinLoadRate_SelectsMinLoadRate(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 30, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(4, 1, 20, nil, AccountTypeAPIKey), + } + result := filterByMinLoadRate(accounts) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(3), result[1].account.ID) +} + +// --- selectByLRU --- + +func TestSelectByLRU_Empty(t *testing.T) { + result := selectByLRU(nil, false) + require.Nil(t, result) +} + +func TestSelectByLRU_Single(t *testing.T) { + accounts := []accountWithLoad{makeAccWithLoad(1, 1, 10, nil, AccountTypeAPIKey)} + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(1), result.account.ID) +} + +func TestSelectByLRU_NilLastUsedAtWins(t *testing.T) { + now := time.Now() + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey), + } + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(2), result.account.ID) +} + +func TestSelectByLRU_EarliestTimeWins(t *testing.T) { + now := time.Now() + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(-2*time.Hour)), AccountTypeAPIKey), + } + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(3), result.account.ID) +} + +func TestSelectByLRU_TiePreferOAuth(t *testing.T) { + now := time.Now() + // 账号 1/2 LastUsedAt 相同,且同为最小值。 + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, testTimePtr(now), AccountTypeOAuth), + makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(1*time.Hour)), AccountTypeAPIKey), + } + for i := 0; i < 50; i++ { + result := selectByLRU(accounts, true) + require.NotNil(t, result) + require.Equal(t, AccountTypeOAuth, result.account.Type) + require.Equal(t, int64(2), result.account.ID) + } +} diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_benchmark_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_benchmark_test.go new file mode 100644 index 00000000..37fd709f --- /dev/null +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_benchmark_test.go @@ -0,0 +1,56 @@ +package service + +import "testing" + +func BenchmarkGatewayService_ParseSSEUsage_MessageStart(b *testing.B) { + svc := &GatewayService{} + data := `{"type":"message_start","message":{"usage":{"input_tokens":123,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage := &ClaudeUsage{} + svc.parseSSEUsage(data, usage) + } +} + +func BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageStart(b *testing.B) { + svc := &GatewayService{} + data := `{"type":"message_start","message":{"usage":{"input_tokens":123,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage := &ClaudeUsage{} + svc.parseSSEUsagePassthrough(data, usage) + } +} + +func BenchmarkGatewayService_ParseSSEUsage_MessageDelta(b *testing.B) { + svc := &GatewayService{} + data := `{"type":"message_delta","usage":{"output_tokens":456,"cache_creation_input_tokens":30,"cache_read_input_tokens":7,"cached_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":20}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage := &ClaudeUsage{} + svc.parseSSEUsage(data, usage) + } +} + +func BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageDelta(b *testing.B) { + svc := &GatewayService{} + data := `{"type":"message_delta","usage":{"output_tokens":456,"cache_creation_input_tokens":30,"cache_read_input_tokens":7,"cached_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":20}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage := &ClaudeUsage{} + svc.parseSSEUsagePassthrough(data, usage) + } +} + +func BenchmarkParseClaudeUsageFromResponseBody(b *testing.B) { + body := []byte(`{"id":"msg_123","type":"message","usage":{"input_tokens":123,"output_tokens":456,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = parseClaudeUsageFromResponseBody(body) + } +} diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go new file mode 100644 index 00000000..f8c0ecda --- /dev/null +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -0,0 +1,875 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +type anthropicHTTPUpstreamRecorder struct { + lastReq *http.Request + lastBody []byte + resp *http.Response + err error +} + +func newAnthropicAPIKeyAccountForTest() *Account { + return &Account{ + ID: 201, + Name: "anthropic-apikey-pass-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-anthropic-key", + "base_url": "https://api.anthropic.com", + }, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + Status: StatusActive, + Schedulable: true, + } +} + +func (u *anthropicHTTPUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.lastReq = req + if req != nil && req.Body != nil { + b, _ := io.ReadAll(req.Body) + u.lastBody = b + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewReader(b)) + } + if u.err != nil { + return nil, u.err + } + return u.resp, nil +} + +func (u *anthropicHTTPUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +type streamReadCloser struct { + payload []byte + sent bool + err error +} + +func (r *streamReadCloser) Read(p []byte) (int, error) { + if !r.sent { + r.sent = true + n := copy(p, r.payload) + return n, nil + } + if r.err != nil { + return 0, r.err + } + return 0, io.EOF +} + +func (r *streamReadCloser) Close() error { return nil } + +type failWriteResponseWriter struct { + gin.ResponseWriter +} + +func (w *failWriteResponseWriter) Write(data []byte) (int, error) { + return 0, errors.New("client disconnected") +} + +func (w *failWriteResponseWriter) WriteString(_ string) (int, error) { + return 0, errors.New("client disconnected") +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAndAuthReplacement(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("User-Agent", "claude-cli/1.0.0") + c.Request.Header.Set("Authorization", "Bearer inbound-token") + c.Request.Header.Set("X-Api-Key", "inbound-api-key") + c.Request.Header.Set("X-Goog-Api-Key", "inbound-goog-key") + c.Request.Header.Set("Cookie", "secret=1") + c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14") + + body := []byte(`{"model":"claude-3-7-sonnet-20250219","stream":true,"system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + parsed := &ParsedRequest{ + Body: body, + Model: "claude-3-7-sonnet-20250219", + Stream: true, + } + + upstreamSSE := strings.Join([]string{ + `data: {"type":"message_start","message":{"usage":{"input_tokens":9,"cached_tokens":7}}}`, + "", + `data: {"type":"message_delta","usage":{"output_tokens":3}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "x-request-id": []string{"rid-anthropic-pass"}, + "Set-Cookie": []string{"secret=upstream"}, + }, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + deferredService: &DeferredService{}, + billingCacheService: nil, + } + + account := &Account{ + ID: 101, + Name: "anthropic-apikey-pass", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-anthropic-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"claude-3-7-sonnet-20250219": "claude-3-haiku-20240307"}, + }, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + Status: StatusActive, + Schedulable: true, + } + + result, err := svc.Forward(context.Background(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + + require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体") + require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String()) + + require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) + require.Empty(t, upstream.lastReq.Header.Get("authorization")) + require.Empty(t, upstream.lastReq.Header.Get("x-goog-api-key")) + require.Empty(t, upstream.lastReq.Header.Get("cookie")) + require.Equal(t, "2023-06-01", upstream.lastReq.Header.Get("anthropic-version")) + require.Equal(t, "interleaved-thinking-2025-05-14", upstream.lastReq.Header.Get("anthropic-beta")) + require.Empty(t, upstream.lastReq.Header.Get("x-stainless-lang"), "API Key 透传不应注入 OAuth 指纹头") + + require.Contains(t, rec.Body.String(), `"cached_tokens":7`) + require.NotContains(t, rec.Body.String(), `"cache_read_input_tokens":7`, "透传输出不应被网关改写") + require.Equal(t, 7, result.Usage.CacheReadInputTokens, "计费 usage 解析应保留 cached_tokens 兼容") + require.Empty(t, rec.Header().Get("Set-Cookie"), "响应头应经过安全过滤") + rawBody, ok := c.Get(OpsUpstreamRequestBodyKey) + require.True(t, ok) + bodyBytes, ok := rawBody.([]byte) + require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝") + require.Equal(t, body, bodyBytes) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + c.Request.Header.Set("Authorization", "Bearer inbound-token") + c.Request.Header.Set("X-Api-Key", "inbound-api-key") + c.Request.Header.Set("Cookie", "secret=1") + + body := []byte(`{"model":"claude-3-5-sonnet-latest","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}],"thinking":{"type":"enabled"}}`) + parsed := &ParsedRequest{ + Body: body, + Model: "claude-3-5-sonnet-latest", + } + + upstreamRespBody := `{"input_tokens":42}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-count"}, + "Set-Cookie": []string{"secret=upstream"}, + }, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 102, + Name: "anthropic-apikey-pass-count", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-anthropic-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"claude-3-5-sonnet-latest": "claude-3-opus-20240229"}, + }, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + + require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体") + require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) + require.Empty(t, upstream.lastReq.Header.Get("authorization")) + require.Empty(t, upstream.lastReq.Header.Get("cookie")) + require.Equal(t, http.StatusOK, rec.Code) + require.JSONEq(t, upstreamRespBody, rec.Body.String()) + require.Empty(t, rec.Header().Get("Set-Cookie")) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + statusCode int + respBody string + wantPassthrough bool + }{ + { + name: "404 endpoint not found passes through as 404", + statusCode: http.StatusNotFound, + respBody: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`, + wantPassthrough: true, + }, + { + name: "404 generic not found does not passthrough", + statusCode: http.StatusNotFound, + respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`, + wantPassthrough: false, + }, + { + name: "400 Invalid URL does not passthrough", + statusCode: http.StatusBadRequest, + respBody: `{"error":{"message":"Invalid URL (POST /v1/messages/count_tokens)","type":"invalid_request_error"}}`, + wantPassthrough: false, + }, + { + name: "400 model error does not passthrough", + statusCode: http.StatusBadRequest, + respBody: `{"error":{"message":"model not found: claude-unknown","type":"invalid_request_error"}}`, + wantPassthrough: false, + }, + { + name: "500 internal error does not passthrough", + statusCode: http.StatusInternalServerError, + respBody: `{"error":{"message":"internal error","type":"api_error"}}`, + wantPassthrough: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + body := []byte(`{"model":"claude-sonnet-4-5-20250929","messages":[{"role":"user","content":"hi"}]}`) + parsed := &ParsedRequest{Body: body, Model: "claude-sonnet-4-5-20250929"} + + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: tt.statusCode, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(tt.respBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }, + httpUpstream: upstream, + rateLimitService: nil, + } + + account := &Account{ + ID: 200, + Name: "proxy-acc", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-proxy", + "base_url": "https://proxy.example.com", + }, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + + if tt.wantPassthrough { + // 返回 nil(不记录为错误),HTTP 状态码 404 + Anthropic 错误体 + require.NoError(t, err) + require.Equal(t, http.StatusNotFound, rec.Code) + var errResp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &errResp)) + require.Equal(t, "error", errResp["type"]) + errObj, ok := errResp["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "not_found_error", errObj["type"]) + } else { + require.Error(t, err) + require.Equal(t, tt.statusCode, rec.Code) + } + }) + } +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_BuildRequestRejectsInvalidBaseURL(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + }, + }, + }, + } + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "k", + "base_url": "://invalid-url", + }, + } + + _, err := svc.buildUpstreamRequestAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "k") + require.Error(t, err) +} + +func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }, + } + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + + require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled()) + + req, err := svc.buildUpstreamRequest(context.Background(), c, account, []byte(`{"model":"claude-3-7-sonnet-20250219"}`), "oauth-token", "oauth", "claude-3-7-sonnet-20250219", true, false) + require.NoError(t, err) + require.Equal(t, "Bearer oauth-token", req.Header.Get("authorization")) + require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta") +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Use a canceled context recorder to simulate client disconnect behavior. + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + cancel() + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`, + "", + `data: {"type":"message_delta","usage":{"output_tokens":5}}`, + "", + "data: [DONE]", + "", + }, "\n"))), + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 11, result.usage.InputTokens) + require.Equal(t, 5, result.usage.OutputTokens) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + body := []byte(`{"model":"claude-3-5-sonnet-latest","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":12,"output_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":2,"ephemeral_1h_input_tokens":3},"cached_tokens":4}}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-nonstream"}, + }, + Body: io.NopCloser(strings.NewReader(upstreamJSON)), + }, + } + svc := &GatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 5, result.Usage.CacheCreationInputTokens) + require.Equal(t, 4, result.Usage.CacheReadInputTokens) + require.Equal(t, upstreamJSON, rec.Body.String()) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenType(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + account := &Account{ + ID: 202, + Name: "anthropic-oauth", + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "oauth-token", + }, + } + svc := &GatewayService{} + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "requires apikey token") +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequestError(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + upstream := &anthropicHTTPUpstreamRecorder{ + err: errors.New("dial tcp timeout"), + } + svc := &GatewayService{ + cfg: &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + }, + httpUpstream: upstream, + } + account := newAnthropicAPIKeyAccountForTest() + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "upstream request failed") + require.Equal(t, http.StatusBadGateway, rec.Code) + rawBody, ok := c.Get(OpsUpstreamRequestBodyKey) + require.True(t, ok) + _, ok = rawBody.([]byte) + require.True(t, ok) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"x-request-id": []string{"rid-empty-body"}}, + Body: nil, + }, + } + svc := &GatewayService{ + cfg: &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + }, + httpUpstream: upstream, + } + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "empty response") +} + +func TestExtractAnthropicSSEDataLine(t *testing.T) { + t.Run("valid data line with spaces", func(t *testing.T) { + data, ok := extractAnthropicSSEDataLine("data: {\"type\":\"message_start\"}") + require.True(t, ok) + require.Equal(t, `{"type":"message_start"}`, data) + }) + + t.Run("non data line", func(t *testing.T) { + data, ok := extractAnthropicSSEDataLine("event: message_start") + require.False(t, ok) + require.Empty(t, data) + }) +} + +func TestGatewayService_ParseSSEUsagePassthrough_MessageStartFallbacks(t *testing.T) { + svc := &GatewayService{} + usage := &ClaudeUsage{} + data := `{"type":"message_start","message":{"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cached_tokens":9,"cache_creation":{"ephemeral_5m_input_tokens":3,"ephemeral_1h_input_tokens":4}}}}` + + svc.parseSSEUsagePassthrough(data, usage) + + require.Equal(t, 12, usage.InputTokens) + require.Equal(t, 9, usage.CacheReadInputTokens, "应兼容 cached_tokens 字段") + require.Equal(t, 7, usage.CacheCreationInputTokens, "聚合字段为空时应从 5m/1h 明细回填") + require.Equal(t, 3, usage.CacheCreation5mTokens) + require.Equal(t, 4, usage.CacheCreation1hTokens) +} + +func TestGatewayService_ParseSSEUsagePassthrough_MessageDeltaSelectiveOverwrite(t *testing.T) { + svc := &GatewayService{} + usage := &ClaudeUsage{ + InputTokens: 10, + CacheCreation5mTokens: 2, + CacheCreation1hTokens: 6, + } + data := `{"type":"message_delta","usage":{"input_tokens":0,"output_tokens":5,"cache_creation_input_tokens":8,"cache_read_input_tokens":0,"cached_tokens":11,"cache_creation":{"ephemeral_5m_input_tokens":1,"ephemeral_1h_input_tokens":0}}}` + + svc.parseSSEUsagePassthrough(data, usage) + + require.Equal(t, 10, usage.InputTokens, "message_delta 中 0 值不应覆盖已有 input_tokens") + require.Equal(t, 5, usage.OutputTokens) + require.Equal(t, 8, usage.CacheCreationInputTokens) + require.Equal(t, 11, usage.CacheReadInputTokens, "cache_read_input_tokens 为空时应回退到 cached_tokens") + require.Equal(t, 1, usage.CacheCreation5mTokens) + require.Equal(t, 6, usage.CacheCreation1hTokens, "message_delta 中 0 值不应覆盖已有 1h 明细") +} + +func TestGatewayService_ParseSSEUsagePassthrough_NoopCases(t *testing.T) { + svc := &GatewayService{} + + usage := &ClaudeUsage{InputTokens: 3} + svc.parseSSEUsagePassthrough("", usage) + require.Equal(t, 3, usage.InputTokens) + + svc.parseSSEUsagePassthrough("[DONE]", usage) + require.Equal(t, 3, usage.InputTokens) + + svc.parseSSEUsagePassthrough("not-json", usage) + require.Equal(t, 3, usage.InputTokens) + + // nil usage 不应 panic + svc.parseSSEUsagePassthrough(`{"type":"message_start"}`, nil) +} + +func TestGatewayService_ParseSSEUsagePassthrough_FallbackFromUsageNode(t *testing.T) { + svc := &GatewayService{} + usage := &ClaudeUsage{} + data := `{"type":"content_block_delta","usage":{"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":2,"ephemeral_1h_input_tokens":1}}}` + + svc.parseSSEUsagePassthrough(data, usage) + + require.Equal(t, 6, usage.CacheReadInputTokens) + require.Equal(t, 3, usage.CacheCreationInputTokens) +} + +func TestParseClaudeUsageFromResponseBody(t *testing.T) { + t.Run("empty or missing usage", func(t *testing.T) { + got := parseClaudeUsageFromResponseBody(nil) + require.NotNil(t, got) + require.Equal(t, 0, got.InputTokens) + + got = parseClaudeUsageFromResponseBody([]byte(`{"id":"x"}`)) + require.NotNil(t, got) + require.Equal(t, 0, got.OutputTokens) + }) + + t.Run("parse all usage fields and fallback", func(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":21,"output_tokens":34,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cached_tokens":13,"cache_creation":{"ephemeral_5m_input_tokens":5,"ephemeral_1h_input_tokens":8}}}`) + got := parseClaudeUsageFromResponseBody(body) + require.Equal(t, 21, got.InputTokens) + require.Equal(t, 34, got.OutputTokens) + require.Equal(t, 13, got.CacheReadInputTokens, "cache_read_input_tokens 为空时应回退 cached_tokens") + require.Equal(t, 13, got.CacheCreationInputTokens, "聚合字段为空时应由 5m/1h 回填") + require.Equal(t, 5, got.CacheCreation5mTokens) + require.Equal(t, 8, got.CacheCreation1hTokens) + }) + + t.Run("keep explicit aggregate values", func(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"cache_creation_input_tokens":9,"cache_read_input_tokens":7,"cached_tokens":99,"cache_creation":{"ephemeral_5m_input_tokens":4,"ephemeral_1h_input_tokens":5}}}`) + got := parseClaudeUsageFromResponseBody(body) + require.Equal(t, 9, got.CacheCreationInputTokens, "已显式提供聚合字段时不应被明细覆盖") + require.Equal(t, 7, got.CacheReadInputTokens, "已显式提供 cache_read_input_tokens 时不应回退 cached_tokens") + }) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingErrTooLong(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: 32, + }, + }, + } + + // Scanner 初始缓冲为 64KB,构造更长单行触发 bufio.ErrTooLong。 + longLine := "data: " + strings.Repeat("x", 80*1024) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(longLine)), + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 2}, time.Now(), "claude-3-7-sonnet-20250219") + require.Error(t, err) + require.ErrorIs(t, err, bufio.ErrTooLong) + require.NotNil(t, result) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 1, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: pr, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 5}, time.Now(), "claude-3-7-sonnet-20250219") + _ = pw.Close() + _ = pr.Close() + + require.Error(t, err) + require.Contains(t, err.Error(), "stream data interval timeout") + require.NotNil(t, result) + require.False(t, result.clientDisconnect) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + err: io.ErrUnexpectedEOF, + }, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 6}, time.Now(), "claude-3-7-sonnet-20250219") + require.Error(t, err) + require.Contains(t, err.Error(), "stream read error") + require.NotNil(t, result) + require.False(t, result.clientDisconnect) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Writer = &failWriteResponseWriter{ResponseWriter: c.Writer} + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 1, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: pr, + } + + done := make(chan struct{}) + go func() { + defer close(done) + _, _ = pw.Write([]byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":9}}}` + "\n")) + // 保持上游连接静默,触发数据间隔超时分支。 + time.Sleep(1500 * time.Millisecond) + _ = pw.Close() + }() + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 7}, time.Now(), "claude-3-7-sonnet-20250219") + _ = pr.Close() + <-done + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.Equal(t, 9, result.usage.InputTokens) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + err: context.Canceled, + }, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Writer = &failWriteResponseWriter{ResponseWriter: c.Writer} + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + payload: []byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":8}}}` + "\n\n"), + err: io.ErrUnexpectedEOF, + }, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.Equal(t, 8, result.usage.InputTokens) +} diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go index dd58c183..21a1faa4 100644 --- a/backend/internal/service/gateway_beta_test.go +++ b/backend/internal/service/gateway_beta_test.go @@ -3,6 +3,8 @@ package service import ( "testing" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/stretchr/testify/require" ) @@ -21,3 +23,180 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) { ) require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got) } + +func TestStripBetaTokens(t *testing.T) { + tests := []struct { + name string + header string + tokens []string + want string + }{ + { + name: "single token in middle", + header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14", + tokens: []string{"context-1m-2025-08-07"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "single token at start", + header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14", + tokens: []string{"context-1m-2025-08-07"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "single token at end", + header: "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07", + tokens: []string{"context-1m-2025-08-07"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "token not present", + header: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + tokens: []string{"context-1m-2025-08-07"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "empty header", + header: "", + tokens: []string{"context-1m-2025-08-07"}, + want: "", + }, + { + name: "with spaces", + header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14", + tokens: []string{"context-1m-2025-08-07"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "only token", + header: "context-1m-2025-08-07", + tokens: []string{"context-1m-2025-08-07"}, + want: "", + }, + { + name: "nil tokens", + header: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + tokens: nil, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "multiple tokens removed", + header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14,fast-mode-2026-02-01", + tokens: []string{"context-1m-2025-08-07", "fast-mode-2026-02-01"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "DroppedBetas removes both context-1m and fast-mode", + header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", + tokens: claude.DroppedBetas, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripBetaTokens(tt.header, tt.tokens) + require.Equal(t, tt.want, got) + }) + } +} + +func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) { + required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"} + incoming := "context-1m-2025-08-07,foo-beta,oauth-2025-04-20" + drop := map[string]struct{}{"context-1m-2025-08-07": {}} + + got := mergeAnthropicBetaDropping(required, incoming, drop) + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got) + require.NotContains(t, got, "context-1m-2025-08-07") +} + +func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) { + required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"} + incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20" + drop := droppedBetaSet() + + got := mergeAnthropicBetaDropping(required, incoming, drop) + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got) + require.NotContains(t, got, "context-1m-2025-08-07") + require.NotContains(t, got, "fast-mode-2026-02-01") +} + +func TestDroppedBetaSet(t *testing.T) { + // Base set contains DroppedBetas + base := droppedBetaSet() + require.Contains(t, base, claude.BetaContext1M) + require.Contains(t, base, claude.BetaFastMode) + require.Len(t, base, len(claude.DroppedBetas)) + + // With extra tokens + extended := droppedBetaSet(claude.BetaClaudeCode) + require.Contains(t, extended, claude.BetaContext1M) + require.Contains(t, extended, claude.BetaFastMode) + require.Contains(t, extended, claude.BetaClaudeCode) + require.Len(t, extended, len(claude.DroppedBetas)+1) +} + +func TestBuildBetaTokenSet(t *testing.T) { + got := buildBetaTokenSet([]string{"foo", "", "bar", "foo"}) + require.Len(t, got, 2) + require.Contains(t, got, "foo") + require.Contains(t, got, "bar") + require.NotContains(t, got, "") + + empty := buildBetaTokenSet(nil) + require.Empty(t, empty) +} + +func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) { + header := "oauth-2025-04-20,interleaved-thinking-2025-05-14" + got := stripBetaTokensWithSet(header, map[string]struct{}{}) + require.Equal(t, header, got) +} + +func TestIsCountTokensUnsupported404(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + want bool + }{ + { + name: "exact endpoint not found", + statusCode: 404, + body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`, + want: true, + }, + { + name: "contains count_tokens and not found", + statusCode: 404, + body: `{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`, + want: true, + }, + { + name: "generic 404", + statusCode: 404, + body: `{"error":{"message":"resource not found","type":"not_found_error"}}`, + want: false, + }, + { + name: "404 with empty error message", + statusCode: 404, + body: `{"error":{"message":"","type":"not_found_error"}}`, + want: false, + }, + { + name: "non-404 status", + statusCode: 400, + body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isCountTokensUnsupported404(tt.statusCode, []byte(tt.body)) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/backend/internal/service/gateway_hotpath_optimization_test.go b/backend/internal/service/gateway_hotpath_optimization_test.go new file mode 100644 index 00000000..161c4ba4 --- /dev/null +++ b/backend/internal/service/gateway_hotpath_optimization_test.go @@ -0,0 +1,786 @@ +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + gocache "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/require" +) + +type userGroupRateRepoHotpathStub struct { + UserGroupRateRepository + + rate *float64 + err error + wait <-chan struct{} + calls atomic.Int64 +} + +func (s *userGroupRateRepoHotpathStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + s.calls.Add(1) + if s.wait != nil { + <-s.wait + } + if s.err != nil { + return nil, s.err + } + return s.rate, nil +} + +type usageLogWindowBatchRepoStub struct { + UsageLogRepository + + batchResult map[int64]*usagestats.AccountStats + batchErr error + batchCalls atomic.Int64 + + singleResult map[int64]*usagestats.AccountStats + singleErr error + singleCalls atomic.Int64 +} + +func (s *usageLogWindowBatchRepoStub) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) { + s.batchCalls.Add(1) + if s.batchErr != nil { + return nil, s.batchErr + } + out := make(map[int64]*usagestats.AccountStats, len(accountIDs)) + for _, id := range accountIDs { + if stats, ok := s.batchResult[id]; ok { + out[id] = stats + } + } + return out, nil +} + +func (s *usageLogWindowBatchRepoStub) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { + s.singleCalls.Add(1) + if s.singleErr != nil { + return nil, s.singleErr + } + if stats, ok := s.singleResult[accountID]; ok { + return stats, nil + } + return &usagestats.AccountStats{}, nil +} + +type sessionLimitCacheHotpathStub struct { + SessionLimitCache + + batchData map[int64]float64 + batchErr error + + setData map[int64]float64 + setErr error +} + +func (s *sessionLimitCacheHotpathStub) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) { + if s.batchErr != nil { + return nil, s.batchErr + } + out := make(map[int64]float64, len(accountIDs)) + for _, id := range accountIDs { + if v, ok := s.batchData[id]; ok { + out[id] = v + } + } + return out, nil +} + +func (s *sessionLimitCacheHotpathStub) SetWindowCost(ctx context.Context, accountID int64, cost float64) error { + if s.setErr != nil { + return s.setErr + } + if s.setData == nil { + s.setData = make(map[int64]float64) + } + s.setData[accountID] = cost + return nil +} + +type modelsListAccountRepoStub struct { + AccountRepository + + byGroup map[int64][]Account + all []Account + err error + + listByGroupCalls atomic.Int64 + listAllCalls atomic.Int64 +} + +type stickyGatewayCacheHotpathStub struct { + GatewayCache + + stickyID int64 + getCalls atomic.Int64 +} + +func (s *stickyGatewayCacheHotpathStub) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { + s.getCalls.Add(1) + if s.stickyID > 0 { + return s.stickyID, nil + } + return 0, errors.New("not found") +} + +func (s *stickyGatewayCacheHotpathStub) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + return nil +} + +func (s *stickyGatewayCacheHotpathStub) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { + return nil +} + +func (s *stickyGatewayCacheHotpathStub) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + return nil +} + +func (s *modelsListAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) { + s.listByGroupCalls.Add(1) + if s.err != nil { + return nil, s.err + } + accounts, ok := s.byGroup[groupID] + if !ok { + return nil, nil + } + out := make([]Account, len(accounts)) + copy(out, accounts) + return out, nil +} + +func (s *modelsListAccountRepoStub) ListSchedulable(ctx context.Context) ([]Account, error) { + s.listAllCalls.Add(1) + if s.err != nil { + return nil, s.err + } + out := make([]Account, len(s.all)) + copy(out, s.all) + return out, nil +} + +func resetGatewayHotpathStatsForTest() { + windowCostPrefetchCacheHitTotal.Store(0) + windowCostPrefetchCacheMissTotal.Store(0) + windowCostPrefetchBatchSQLTotal.Store(0) + windowCostPrefetchFallbackTotal.Store(0) + windowCostPrefetchErrorTotal.Store(0) + + userGroupRateCacheHitTotal.Store(0) + userGroupRateCacheMissTotal.Store(0) + userGroupRateCacheLoadTotal.Store(0) + userGroupRateCacheSFSharedTotal.Store(0) + userGroupRateCacheFallbackTotal.Store(0) + + modelsListCacheHitTotal.Store(0) + modelsListCacheMissTotal.Store(0) + modelsListCacheStoreTotal.Store(0) +} + +func TestGetUserGroupRateMultiplier_UsesCacheAndSingleflight(t *testing.T) { + resetGatewayHotpathStatsForTest() + + rate := 1.7 + unblock := make(chan struct{}) + repo := &userGroupRateRepoHotpathStub{ + rate: &rate, + wait: unblock, + } + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 30, + }, + }, + } + + const concurrent = 12 + results := make([]float64, concurrent) + start := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(concurrent) + for i := 0; i < concurrent; i++ { + go func(idx int) { + defer wg.Done() + <-start + results[idx] = svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2) + }(i) + } + + close(start) + time.Sleep(20 * time.Millisecond) + close(unblock) + wg.Wait() + + for _, got := range results { + require.Equal(t, rate, got) + } + require.Equal(t, int64(1), repo.calls.Load()) + + // 再次读取应命中缓存,不再回源。 + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2) + require.Equal(t, rate, got) + require.Equal(t, int64(1), repo.calls.Load()) + + hit, miss, load, sfShared, fallback := GatewayUserGroupRateCacheStats() + require.GreaterOrEqual(t, hit, int64(1)) + require.Equal(t, int64(12), miss) + require.Equal(t, int64(1), load) + require.GreaterOrEqual(t, sfShared, int64(1)) + require.Equal(t, int64(0), fallback) +} + +func TestGetUserGroupRateMultiplier_FallbackOnRepoError(t *testing.T) { + resetGatewayHotpathStatsForTest() + + repo := &userGroupRateRepoHotpathStub{ + err: errors.New("db down"), + } + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 30, + }, + }, + } + + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.25) + require.Equal(t, 1.25, got) + require.Equal(t, int64(1), repo.calls.Load()) + + _, _, _, _, fallback := GatewayUserGroupRateCacheStats() + require.Equal(t, int64(1), fallback) +} + +func TestGetUserGroupRateMultiplier_CacheHitAndNilRepo(t *testing.T) { + resetGatewayHotpathStatsForTest() + + repo := &userGroupRateRepoHotpathStub{ + err: errors.New("should not be called"), + } + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + } + key := "101:202" + svc.userGroupRateCache.Set(key, 2.3, time.Minute) + + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.1) + require.Equal(t, 2.3, got) + + hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats() + require.Equal(t, int64(1), hit) + require.Equal(t, int64(0), miss) + require.Equal(t, int64(0), load) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), repo.calls.Load()) + + // 无 repo 时直接返回分组默认倍率 + svc2 := &GatewayService{ + userGroupRateCache: gocache.New(time.Minute, time.Minute), + } + svc2.userGroupRateCache.Set(key, 1.9, time.Minute) + require.Equal(t, 1.9, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4)) + require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 0, 202, 1.4)) + svc2.userGroupRateCache.Delete(key) + require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4)) +} + +func TestWithWindowCostPrefetch_BatchReadAndContextReuse(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 3, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{"window_cost_limit": 100.0}, + }, + } + + cache := &sessionLimitCacheHotpathStub{ + batchData: map[int64]float64{ + 1: 11.0, + }, + } + repo := &usageLogWindowBatchRepoStub{ + batchResult: map[int64]*usagestats.AccountStats{ + 2: {StandardCost: 22.0}, + }, + } + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + require.NotNil(t, outCtx) + + cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1) + require.True(t, ok1) + require.Equal(t, 11.0, cost1) + + cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok2) + require.Equal(t, 22.0, cost2) + + _, ok3 := windowCostFromPrefetchContext(outCtx, 3) + require.False(t, ok3) + + require.Equal(t, int64(1), repo.batchCalls.Load()) + require.Equal(t, 22.0, cache.setData[2]) + + hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(1), hit) + require.Equal(t, int64(1), miss) + require.Equal(t, int64(1), batchSQL) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), errCount) +} + +func TestWithWindowCostPrefetch_AllHitNoSQL(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + } + + cache := &sessionLimitCacheHotpathStub{ + batchData: map[int64]float64{ + 1: 11.0, + 2: 22.0, + }, + } + repo := &usageLogWindowBatchRepoStub{} + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1) + cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok1) + require.True(t, ok2) + require.Equal(t, 11.0, cost1) + require.Equal(t, 22.0, cost2) + require.Equal(t, int64(0), repo.batchCalls.Load()) + require.Equal(t, int64(0), repo.singleCalls.Load()) + + hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(2), hit) + require.Equal(t, int64(0), miss) + require.Equal(t, int64(0), batchSQL) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), errCount) +} + +func TestWithWindowCostPrefetch_BatchErrorFallbackSingleQuery(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + } + + cache := &sessionLimitCacheHotpathStub{} + repo := &usageLogWindowBatchRepoStub{ + batchErr: errors.New("batch failed"), + singleResult: map[int64]*usagestats.AccountStats{ + 2: {StandardCost: 33.0}, + }, + } + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + cost, ok := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok) + require.Equal(t, 33.0, cost) + require.Equal(t, int64(1), repo.batchCalls.Load()) + require.Equal(t, int64(1), repo.singleCalls.Load()) + + _, _, _, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(1), fallback) + require.Equal(t, int64(1), errCount) +} + +func TestGetAvailableModels_UsesShortCacheAndSupportsInvalidation(t *testing.T) { + resetGatewayHotpathStatsForTest() + + groupID := int64(9) + repo := &modelsListAccountRepoStub{ + byGroup: map[int64][]Account{ + groupID: { + { + ID: 1, + Platform: PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "claude-3-5-sonnet", + "claude-3-5-haiku": "claude-3-5-haiku", + }, + }, + }, + { + ID: 2, + Platform: PlatformGemini, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-2.5-pro": "gemini-2.5-pro", + }, + }, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + models1 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models1) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + // TTL 内再次请求应命中缓存,不回源。 + models2 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, models1, models2) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + // 更新仓储数据,但缓存未失效前应继续返回旧值。 + repo.byGroup[groupID] = []Account{ + { + ID: 3, + Platform: PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-7-sonnet": "claude-3-7-sonnet", + }, + }, + }, + } + models3 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models3) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + svc.InvalidateAvailableModelsCache(&groupID, PlatformAnthropic) + models4 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-7-sonnet"}, models4) + require.Equal(t, int64(2), repo.listByGroupCalls.Load()) + + hit, miss, store := GatewayModelsListCacheStats() + require.Equal(t, int64(2), hit) + require.Equal(t, int64(2), miss) + require.Equal(t, int64(2), store) +} + +func TestGetAvailableModels_ErrorAndGlobalListBranches(t *testing.T) { + resetGatewayHotpathStatsForTest() + + errRepo := &modelsListAccountRepoStub{ + err: errors.New("db error"), + } + svcErr := &GatewayService{ + accountRepo: errRepo, + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + require.Nil(t, svcErr.GetAvailableModels(context.Background(), nil, "")) + + okRepo := &modelsListAccountRepoStub{ + all: []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "claude-3-5-sonnet", + }, + }, + }, + { + ID: 2, + Platform: PlatformGemini, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-2.5-pro": "gemini-2.5-pro", + }, + }, + }, + }, + } + svcOK := &GatewayService{ + accountRepo: okRepo, + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + models := svcOK.GetAvailableModels(context.Background(), nil, "") + require.Equal(t, []string{"claude-3-5-sonnet", "gemini-2.5-pro"}, models) + require.Equal(t, int64(1), okRepo.listAllCalls.Load()) +} + +func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) { + t.Run("resolve_user_group_rate_cache_ttl", func(t *testing.T) { + require.Equal(t, defaultUserGroupRateCacheTTL, resolveUserGroupRateCacheTTL(nil)) + + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 45, + }, + } + require.Equal(t, 45*time.Second, resolveUserGroupRateCacheTTL(cfg)) + }) + + t.Run("resolve_models_list_cache_ttl", func(t *testing.T) { + require.Equal(t, defaultModelsListCacheTTL, resolveModelsListCacheTTL(nil)) + + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + ModelsListCacheTTLSeconds: 20, + }, + } + require.Equal(t, 20*time.Second, resolveModelsListCacheTTL(cfg)) + }) + + t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) { + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO(), nil)) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background(), nil)) + + ctx := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0)) + require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx, nil)) + + groupID := int64(9) + ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456) + ctx2 = context.WithValue(ctx2, ctxkey.PrefetchedStickyGroupID, groupID) + require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2, &groupID)) + + ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid") + ctx3 = context.WithValue(ctx3, ctxkey.PrefetchedStickyGroupID, groupID) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3, &groupID)) + + ctx4 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(789)) + ctx4 = context.WithValue(ctx4, ctxkey.PrefetchedStickyGroupID, int64(10)) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx4, &groupID)) + }) + + t.Run("window_cost_from_prefetch_context", func(t *testing.T) { + require.Equal(t, false, func() bool { + _, ok := windowCostFromPrefetchContext(context.TODO(), 0) + return ok + }()) + require.Equal(t, false, func() bool { + _, ok := windowCostFromPrefetchContext(context.Background(), 1) + return ok + }()) + + ctx := context.WithValue(context.Background(), windowCostPrefetchContextKey, map[int64]float64{ + 9: 12.34, + }) + cost, ok := windowCostFromPrefetchContext(ctx, 9) + require.True(t, ok) + require.Equal(t, 12.34, cost) + }) +} + +func TestInvalidateAvailableModelsCache_ByDimensions(t *testing.T) { + svc := &GatewayService{ + modelsListCache: gocache.New(time.Minute, time.Minute), + } + group9 := int64(9) + group10 := int64(10) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute) + svc.modelsListCache.Set("invalid-key", []string{"d"}, time.Minute) + + t.Run("invalidate_group_and_platform", func(t *testing.T) { + svc.InvalidateAvailableModelsCache(&group9, PlatformAnthropic) + _, found := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + require.False(t, found) + _, stillFound := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.True(t, stillFound) + }) + + t.Run("invalidate_group_only", func(t *testing.T) { + svc.InvalidateAvailableModelsCache(&group9, "") + _, foundA := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + _, foundB := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.False(t, foundA) + require.False(t, foundB) + _, foundOtherGroup := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic)) + require.True(t, foundOtherGroup) + }) + + t.Run("invalidate_platform_only", func(t *testing.T) { + // 重建数据后仅按 platform 失效 + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute) + + svc.InvalidateAvailableModelsCache(nil, PlatformAnthropic) + _, found9Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + _, found10Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic)) + _, found9Gemini := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.False(t, found9Anthropic) + require.False(t, found10Anthropic) + require.True(t, found9Gemini) + }) +} + +func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { + now := time.Now().Add(-time.Minute) + account := Account{ + ID: 88, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 4, + Priority: 1, + LastUsedAt: &now, + } + + repo := stubOpenAIAccountRepo{accounts: []Account{account}} + concurrency := NewConcurrencyService(stubConcurrencyCache{}) + + cfg := &config.Config{ + RunMode: config.RunModeStandard, + Gateway: config.GatewayConfig{ + Scheduling: config.GatewaySchedulingConfig{ + LoadBatchEnabled: true, + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: time.Second, + FallbackWaitTimeout: time.Second, + FallbackMaxWaiting: 10, + }, + }, + } + + baseCtx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAnthropic) + + t.Run("without_prefetch_reads_cache_once", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(1), cache.getCalls.Load()) + }) + + t.Run("with_prefetch_skips_cache_read", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0)) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(0), cache.getCalls.Load()) + }) + + t.Run("with_prefetch_group_mismatch_reads_cache", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77)) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(1), cache.getCalls.Load()) + }) +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index b3e60c21..067a0e08 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -77,6 +77,14 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { return nil, nil } + +func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { + return nil, nil +} + +func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { + return nil, nil +} func (m *mockAccountRepoForPlatform) Update(ctx context.Context, account *Account) error { return nil } @@ -84,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { @@ -142,9 +150,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } -func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - return nil -} func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { return nil } @@ -216,22 +221,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context return nil } -func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - return 0, nil -} - -func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { - return nil, nil -} - -func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - type mockGroupRepoForGateway struct { groups map[int64]*Group getByIDCalls int @@ -290,6 +279,10 @@ func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, g return nil, nil } +func (m *mockGroupRepoForGateway) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { + return nil +} + func ptr[T any](v T) *T { return &v } @@ -902,6 +895,55 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *t require.Equal(t, int64(2), acc.ID) } +func TestGatewayService_SelectAccountForModelWithPlatform_GeminiAPIKeyModelMappingFilter(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + { + ID: 1, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Priority: 1, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}}, + }, + { + ID: 2, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Priority: 2, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-flash": "gemini-2.5-flash"}}, + }, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-flash", nil, PlatformGemini) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "应过滤不支持请求模型的 APIKey 账号") + + acc, err = svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-3-pro-preview", nil, PlatformGemini) + require.Error(t, err) + require.Nil(t, acc) + require.Contains(t, err.Error(), "supporting model") +} + func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) { ctx := context.Background() groupID := int64(50) @@ -1077,6 +1119,36 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) { model: "claude-3-5-sonnet-20241022", expected: true, }, + { + name: "Gemini平台-无映射配置-支持所有模型", + account: &Account{Platform: PlatformGemini, Type: AccountTypeAPIKey}, + model: "gemini-2.5-flash", + expected: true, + }, + { + name: "Gemini平台-有映射配置-只支持配置的模型", + account: &Account{ + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}, + }, + }, + model: "gemini-2.5-flash", + expected: false, + }, + { + name: "Gemini平台-有映射配置-支持配置的模型", + account: &Account{ + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}, + }, + }, + model: "gemini-2.5-pro", + expected: true, + }, } for _, tt := range tests { @@ -1820,6 +1892,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun return 0, nil } +func (m *mockConcurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 + } + return result, nil +} + func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { return true, nil } diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 0ecd18aa..b546fe85 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -5,10 +5,39 @@ import ( "encoding/json" "fmt" "math" + "unsafe" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) +var ( + // 这些字节模式用于 fast-path 判断,避免每次 []byte("...") 产生临时分配。 + patternTypeThinking = []byte(`"type":"thinking"`) + patternTypeThinkingSpaced = []byte(`"type": "thinking"`) + patternTypeRedactedThinking = []byte(`"type":"redacted_thinking"`) + patternTypeRedactedSpaced = []byte(`"type": "redacted_thinking"`) + + patternThinkingField = []byte(`"thinking":`) + patternThinkingFieldSpaced = []byte(`"thinking" :`) + + patternEmptyContent = []byte(`"content":[]`) + patternEmptyContentSpaced = []byte(`"content": []`) + patternEmptyContentSp1 = []byte(`"content" : []`) + patternEmptyContentSp2 = []byte(`"content" :[]`) +) + +// SessionContext 粘性会话上下文,用于区分不同来源的请求。 +// 仅在 GenerateSessionHash 第 3 级 fallback(消息内容 hash)时混入, +// 避免不同用户发送相同消息产生相同 hash 导致账号集中。 +type SessionContext struct { + ClientIP string + UserAgent string + APIKeyID int64 +} + // ParsedRequest 保存网关请求的预解析结果 // // 性能优化说明: @@ -22,121 +51,156 @@ import ( // 2. 将解析结果 ParsedRequest 传递给 Service 层 // 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销 type ParsedRequest struct { - Body []byte // 原始请求体(保留用于转发) - Model string // 请求的模型名称 - Stream bool // 是否为流式请求 - MetadataUserID string // metadata.user_id(用于会话亲和) - System any // system 字段内容 - Messages []any // messages 数组 - HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) - ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) - MaxTokens int // max_tokens 值(用于探测请求拦截) + Body []byte // 原始请求体(保留用于转发) + Model string // 请求的模型名称 + Stream bool // 是否为流式请求 + MetadataUserID string // metadata.user_id(用于会话亲和) + System any // system 字段内容 + Messages []any // messages 数组 + HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) + ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) + MaxTokens int // max_tokens 值(用于探测请求拦截) + SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变) + + // OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁) + // 流式请求在收到 2xx 响应头后调用,避免持锁等流完成 + OnUpstreamAccepted func() } -// ParseGatewayRequest 解析网关请求体并返回结构化结果 -// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal -func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { - return nil, err +// ParseGatewayRequest 解析网关请求体并返回结构化结果。 +// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini), +// 不同协议使用不同的 system/messages 字段名。 +func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { + // 保持与旧实现一致:请求体必须是合法 JSON。 + // 注意:gjson.GetBytes 对非法 JSON 不会报错,因此需要显式校验。 + if !gjson.ValidBytes(body) { + return nil, fmt.Errorf("invalid json") } + // 性能: + // - gjson.GetBytes 会把匹配的 Raw/Str 安全复制成 string(对于巨大 messages 会产生额外拷贝)。 + // - 这里将 body 通过 unsafe 零拷贝视为 string,仅在本函数内使用,且 body 不会被修改。 + jsonStr := *(*string)(unsafe.Pointer(&body)) + parsed := &ParsedRequest{ Body: body, } - if rawModel, exists := req["model"]; exists { - model, ok := rawModel.(string) - if !ok { + // --- gjson 提取简单字段(避免完整 Unmarshal) --- + + // model: 需要严格类型校验,非 string 返回错误 + modelResult := gjson.Get(jsonStr, "model") + if modelResult.Exists() { + if modelResult.Type != gjson.String { return nil, fmt.Errorf("invalid model field type") } - parsed.Model = model + parsed.Model = modelResult.String() } - if rawStream, exists := req["stream"]; exists { - stream, ok := rawStream.(bool) - if !ok { + + // stream: 需要严格类型校验,非 bool 返回错误 + streamResult := gjson.Get(jsonStr, "stream") + if streamResult.Exists() { + if streamResult.Type != gjson.True && streamResult.Type != gjson.False { return nil, fmt.Errorf("invalid stream field type") } - parsed.Stream = stream - } - if metadata, ok := req["metadata"].(map[string]any); ok { - if userID, ok := metadata["user_id"].(string); ok { - parsed.MetadataUserID = userID - } - } - // system 字段只要存在就视为显式提供(即使为 null), - // 以避免客户端传 null 时被默认 system 误注入。 - if system, ok := req["system"]; ok { - parsed.HasSystem = true - parsed.System = system - } - if messages, ok := req["messages"].([]any); ok { - parsed.Messages = messages + parsed.Stream = streamResult.Bool() } - // thinking: {type: "enabled"} - if rawThinking, ok := req["thinking"].(map[string]any); ok { - if t, ok := rawThinking["type"].(string); ok && t == "enabled" { - parsed.ThinkingEnabled = true + // metadata.user_id: 直接路径提取,不需要严格类型校验 + parsed.MetadataUserID = gjson.Get(jsonStr, "metadata.user_id").String() + + // thinking.type: enabled/adaptive 都视为开启 + thinkingType := gjson.Get(jsonStr, "thinking.type").String() + if thinkingType == "enabled" || thinkingType == "adaptive" { + parsed.ThinkingEnabled = true + } + + // max_tokens: 仅接受整数值 + maxTokensResult := gjson.Get(jsonStr, "max_tokens") + if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number { + f := maxTokensResult.Float() + if !math.IsNaN(f) && !math.IsInf(f, 0) && f == math.Trunc(f) && + f <= float64(math.MaxInt) && f >= float64(math.MinInt) { + parsed.MaxTokens = int(f) } } - // max_tokens - if rawMaxTokens, exists := req["max_tokens"]; exists { - if maxTokens, ok := parseIntegralNumber(rawMaxTokens); ok { - parsed.MaxTokens = maxTokens + // --- system/messages 提取 --- + // 避免把整个 body Unmarshal 到 map(会产生大量 map/接口分配)。 + // 使用 gjson 抽取目标字段的 Raw,再对该子树进行 Unmarshal。 + + switch protocol { + case domain.PlatformGemini: + // Gemini 原生格式: systemInstruction.parts / contents + if sysParts := gjson.Get(jsonStr, "systemInstruction.parts"); sysParts.Exists() && sysParts.IsArray() { + var parts []any + if err := json.Unmarshal(sliceRawFromBody(body, sysParts), &parts); err != nil { + return nil, err + } + parsed.System = parts + } + + if contents := gjson.Get(jsonStr, "contents"); contents.Exists() && contents.IsArray() { + var msgs []any + if err := json.Unmarshal(sliceRawFromBody(body, contents), &msgs); err != nil { + return nil, err + } + parsed.Messages = msgs + } + default: + // Anthropic / OpenAI 格式: system / messages + // system 字段只要存在就视为显式提供(即使为 null), + // 以避免客户端传 null 时被默认 system 误注入。 + if sys := gjson.Get(jsonStr, "system"); sys.Exists() { + parsed.HasSystem = true + switch sys.Type { + case gjson.Null: + parsed.System = nil + case gjson.String: + // 与 encoding/json 的 Unmarshal 行为一致:返回解码后的字符串。 + parsed.System = sys.String() + default: + var system any + if err := json.Unmarshal(sliceRawFromBody(body, sys), &system); err != nil { + return nil, err + } + parsed.System = system + } + } + + if msgs := gjson.Get(jsonStr, "messages"); msgs.Exists() && msgs.IsArray() { + var messages []any + if err := json.Unmarshal(sliceRawFromBody(body, msgs), &messages); err != nil { + return nil, err + } + parsed.Messages = messages } } return parsed, nil } -// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。 -// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。 -func parseIntegralNumber(raw any) (int, bool) { - switch v := raw.(type) { - case float64: - if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) { - return 0, false +// sliceRawFromBody 返回 Result.Raw 对应的原始字节切片。 +// 优先使用 Result.Index 直接从 body 切片,避免对大字段(如 messages)产生额外拷贝。 +// 当 Index 不可用时,退化为复制(理论上极少发生)。 +func sliceRawFromBody(body []byte, r gjson.Result) []byte { + if r.Index > 0 { + end := r.Index + len(r.Raw) + if end <= len(body) { + return body[r.Index:end] } - if v > float64(math.MaxInt) || v < float64(math.MinInt) { - return 0, false - } - return int(v), true - case int: - return v, true - case int8: - return int(v), true - case int16: - return int(v), true - case int32: - return int(v), true - case int64: - if v > int64(math.MaxInt) || v < int64(math.MinInt) { - return 0, false - } - return int(v), true - case json.Number: - i64, err := v.Int64() - if err != nil { - return 0, false - } - if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) { - return 0, false - } - return int(i64), true - default: - return 0, false } + // fallback: 不影响正确性,但会产生一次拷贝 + return []byte(r.Raw) } // FilterThinkingBlocks removes thinking blocks from request body // Returns filtered body or original body if filtering fails (fail-safe) // This prevents 400 errors from invalid thinking block signatures // -// Strategy: -// - When thinking.type != "enabled": Remove all thinking blocks -// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures +// 策略: +// - 当 thinking.type 不是 "enabled"/"adaptive":移除所有 thinking 相关块 +// - 当 thinking.type 是 "enabled"/"adaptive":仅移除缺失/无效 signature 的 thinking 块(避免 400) // (blocks with missing/empty/dummy signatures that would cause 400 errors) func FilterThinkingBlocks(body []byte) []byte { return filterThinkingBlocksInternal(body, false) @@ -157,49 +221,63 @@ func FilterThinkingBlocks(body []byte) []byte { // - Remove `redacted_thinking` blocks (cannot be converted to text). // - Ensure no message ends up with empty content. func FilterThinkingBlocksForRetry(body []byte) []byte { - hasThinkingContent := bytes.Contains(body, []byte(`"type":"thinking"`)) || - bytes.Contains(body, []byte(`"type": "thinking"`)) || - bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) || - bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) || - bytes.Contains(body, []byte(`"thinking":`)) || - bytes.Contains(body, []byte(`"thinking" :`)) + hasThinkingContent := bytes.Contains(body, patternTypeThinking) || + bytes.Contains(body, patternTypeThinkingSpaced) || + bytes.Contains(body, patternTypeRedactedThinking) || + bytes.Contains(body, patternTypeRedactedSpaced) || + bytes.Contains(body, patternThinkingField) || + bytes.Contains(body, patternThinkingFieldSpaced) // Also check for empty content arrays that need fixing. // Note: This is a heuristic check; the actual empty content handling is done below. - hasEmptyContent := bytes.Contains(body, []byte(`"content":[]`)) || - bytes.Contains(body, []byte(`"content": []`)) || - bytes.Contains(body, []byte(`"content" : []`)) || - bytes.Contains(body, []byte(`"content" :[]`)) + hasEmptyContent := bytes.Contains(body, patternEmptyContent) || + bytes.Contains(body, patternEmptyContentSpaced) || + bytes.Contains(body, patternEmptyContentSp1) || + bytes.Contains(body, patternEmptyContentSp2) // Fast path: nothing to process if !hasThinkingContent && !hasEmptyContent { return body } - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { + // 尽量避免把整个 body Unmarshal 成 map(会产生大量 map/接口分配)。 + // 这里先用 gjson 把 messages 子树摘出来,后续只对 messages 做 Unmarshal/Marshal。 + jsonStr := *(*string)(unsafe.Pointer(&body)) + msgsRes := gjson.Get(jsonStr, "messages") + if !msgsRes.Exists() || !msgsRes.IsArray() { + return body + } + + // Fast path:只需要删除顶层 thinking,不需要改 messages。 + // 注意:patternThinkingField 可能来自嵌套字段(如 tool_use.input.thinking),因此必须用 gjson 判断顶层字段是否存在。 + containsThinkingBlocks := bytes.Contains(body, patternTypeThinking) || + bytes.Contains(body, patternTypeThinkingSpaced) || + bytes.Contains(body, patternTypeRedactedThinking) || + bytes.Contains(body, patternTypeRedactedSpaced) || + bytes.Contains(body, patternThinkingFieldSpaced) + if !hasEmptyContent && !containsThinkingBlocks { + if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() { + if out, err := sjson.DeleteBytes(body, "thinking"); err == nil { + return out + } + return body + } + return body + } + + var messages []any + if err := json.Unmarshal(sliceRawFromBody(body, msgsRes), &messages); err != nil { return body } modified := false - messages, ok := req["messages"].([]any) - if !ok { - return body - } - // Disable top-level thinking mode for retry to avoid structural/signature constraints upstream. - if _, exists := req["thinking"]; exists { - delete(req, "thinking") - modified = true - } + deleteTopLevelThinking := gjson.Get(jsonStr, "thinking").Exists() - newMessages := make([]any, 0, len(messages)) - - for _, msg := range messages { - msgMap, ok := msg.(map[string]any) + for i := 0; i < len(messages); i++ { + msgMap, ok := messages[i].(map[string]any) if !ok { - newMessages = append(newMessages, msg) continue } @@ -207,17 +285,30 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { content, ok := msgMap["content"].([]any) if !ok { // String content or other format - keep as is - newMessages = append(newMessages, msg) continue } - newContent := make([]any, 0, len(content)) + // 延迟分配:只有检测到需要修改的块,才构建新 slice。 + var newContent []any modifiedThisMsg := false - for _, block := range content { + ensureNewContent := func(prefixLen int) { + if newContent != nil { + return + } + newContent = make([]any, 0, len(content)) + if prefixLen > 0 { + newContent = append(newContent, content[:prefixLen]...) + } + } + + for bi := 0; bi < len(content); bi++ { + block := content[bi] blockMap, ok := block.(map[string]any) if !ok { - newContent = append(newContent, block) + if newContent != nil { + newContent = append(newContent, block) + } continue } @@ -227,17 +318,15 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { switch blockType { case "thinking": modifiedThisMsg = true + ensureNewContent(bi) thinkingText, _ := blockMap["thinking"].(string) - if thinkingText == "" { - continue + if thinkingText != "" { + newContent = append(newContent, map[string]any{"type": "text", "text": thinkingText}) } - newContent = append(newContent, map[string]any{ - "type": "text", - "text": thinkingText, - }) continue case "redacted_thinking": modifiedThisMsg = true + ensureNewContent(bi) continue } @@ -245,6 +334,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { if blockType == "" { if rawThinking, hasThinking := blockMap["thinking"]; hasThinking { modifiedThisMsg = true + ensureNewContent(bi) switch v := rawThinking.(type) { case string: if v != "" { @@ -259,40 +349,64 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { } } - newContent = append(newContent, block) + if newContent != nil { + newContent = append(newContent, block) + } } // Handle empty content: either from filtering or originally empty + if newContent == nil { + if len(content) == 0 { + modified = true + placeholder := "(content removed)" + if role == "assistant" { + placeholder = "(assistant content removed)" + } + msgMap["content"] = []any{map[string]any{"type": "text", "text": placeholder}} + } + continue + } + if len(newContent) == 0 { modified = true placeholder := "(content removed)" if role == "assistant" { placeholder = "(assistant content removed)" } - newContent = append(newContent, map[string]any{ - "type": "text", - "text": placeholder, - }) - msgMap["content"] = newContent - } else if modifiedThisMsg { + msgMap["content"] = []any{map[string]any{"type": "text", "text": placeholder}} + continue + } + + if modifiedThisMsg { modified = true msgMap["content"] = newContent } - newMessages = append(newMessages, msgMap) } - if modified { - req["messages"] = newMessages - } else { + if !modified && !deleteTopLevelThinking { // Avoid rewriting JSON when no changes are needed. return body } - newBody, err := json.Marshal(req) - if err != nil { - return body + out := body + if deleteTopLevelThinking { + if b, err := sjson.DeleteBytes(out, "thinking"); err == nil { + out = b + } else { + return body + } } - return newBody + if modified { + msgsBytes, err := json.Marshal(messages) + if err != nil { + return body + } + out, err = sjson.SetRawBytes(out, "messages", msgsBytes) + if err != nil { + return body + } + } + return out } // FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate @@ -462,9 +576,9 @@ func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte { } // filterThinkingBlocksInternal removes invalid thinking blocks from request -// Strategy: -// - When thinking.type != "enabled": Remove all thinking blocks -// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures +// 策略: +// - 当 thinking.type 不是 "enabled"/"adaptive":移除所有 thinking 相关块 +// - 当 thinking.type 是 "enabled"/"adaptive":仅移除缺失/无效 signature 的 thinking 块 func filterThinkingBlocksInternal(body []byte, _ bool) []byte { // Fast path: if body doesn't contain "thinking", skip parsing if !bytes.Contains(body, []byte(`"type":"thinking"`)) && @@ -484,7 +598,7 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte { // Check if thinking is enabled thinkingEnabled := false if thinking, ok := req["thinking"].(map[string]any); ok { - if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" { + if thinkType, ok := thinking["type"].(string); ok && (thinkType == "enabled" || thinkType == "adaptive") { thinkingEnabled = true } } diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index 4e390b0a..2a9b4017 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -1,15 +1,20 @@ +//go:build unit + package service import ( "encoding/json" + "fmt" + "strings" "testing" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/stretchr/testify/require" ) func TestParseGatewayRequest(t *testing.T) { body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, "claude-3-7-sonnet", parsed.Model) require.True(t, parsed.Stream) @@ -22,7 +27,15 @@ func TestParseGatewayRequest(t *testing.T) { func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") + require.NoError(t, err) + require.Equal(t, "claude-sonnet-4-5", parsed.Model) + require.True(t, parsed.ThinkingEnabled) +} + +func TestParseGatewayRequest_ThinkingAdaptiveEnabled(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"adaptive"},"messages":[{"content":"hi"}]}`) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, "claude-sonnet-4-5", parsed.Model) require.True(t, parsed.ThinkingEnabled) @@ -30,21 +43,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { func TestParseGatewayRequest_MaxTokens(t *testing.T) { body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, 1, parsed.MaxTokens) } func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) { body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, 0, parsed.MaxTokens) } func TestParseGatewayRequest_SystemNull(t *testing.T) { body := []byte(`{"model":"claude-3","system":null}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) // 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。 require.True(t, parsed.HasSystem) @@ -53,16 +66,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) { func TestParseGatewayRequest_InvalidModelType(t *testing.T) { body := []byte(`{"model":123}`) - _, err := ParseGatewayRequest(body) + _, err := ParseGatewayRequest(body, "") require.Error(t, err) } func TestParseGatewayRequest_InvalidStreamType(t *testing.T) { body := []byte(`{"stream":"true"}`) - _, err := ParseGatewayRequest(body) + _, err := ParseGatewayRequest(body, "") require.Error(t, err) } +// ============ Gemini 原生格式解析测试 ============ + +func TestParseGatewayRequest_GeminiContents(t *testing.T) { + body := []byte(`{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]}, + {"role": "model", "parts": [{"text": "Hi there"}]}, + {"role": "user", "parts": [{"text": "How are you?"}]} + ] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Len(t, parsed.Messages, 3, "should parse contents as Messages") + require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem") + require.Nil(t, parsed.System, "no systemInstruction means nil System") +} + +func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) { + body := []byte(`{ + "systemInstruction": { + "parts": [{"text": "You are a helpful assistant."}] + }, + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System") + parts, ok := parsed.System.([]any) + require.True(t, ok) + require.Len(t, parts, 1) + partMap, ok := parts[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "You are a helpful assistant.", partMap["text"]) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_GeminiWithModel(t *testing.T) { + body := []byte(`{ + "model": "gemini-2.5-pro", + "contents": [{"role": "user", "parts": [{"text": "test"}]}] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Equal(t, "gemini-2.5-pro", parsed.Model) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) { + // Gemini 格式下 system/messages 字段应被忽略 + body := []byte(`{ + "system": "should be ignored", + "messages": [{"role": "user", "content": "ignored"}], + "contents": [{"role": "user", "parts": [{"text": "real content"}]}] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field") + require.Nil(t, parsed.System, "no systemInstruction = nil System") + require.Len(t, parsed.Messages, 1, "should use contents, not messages") +} + +func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) { + body := []byte(`{"contents": []}`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Empty(t, parsed.Messages) +} + +func TestParseGatewayRequest_GeminiNoContents(t *testing.T) { + body := []byte(`{"model": "gemini-2.5-flash"}`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Nil(t, parsed.Messages) + require.Equal(t, "gemini-2.5-flash", parsed.Model) +} + +func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) { + // Anthropic 格式下 contents/systemInstruction 字段应被忽略 + body := []byte(`{ + "system": "real system", + "messages": [{"role": "user", "content": "real content"}], + "contents": [{"role": "user", "parts": [{"text": "ignored"}]}], + "systemInstruction": {"parts": [{"text": "ignored"}]} + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic) + require.NoError(t, err) + require.True(t, parsed.HasSystem) + require.Equal(t, "real system", parsed.System) + require.Len(t, parsed.Messages, 1) + msg, ok := parsed.Messages[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "real content", msg["content"]) +} + func TestFilterThinkingBlocks(t *testing.T) { containsThinkingBlock := func(body []byte) bool { var req map[string]any @@ -112,6 +221,16 @@ func TestFilterThinkingBlocks(t *testing.T) { input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`, shouldFilter: true, }, + { + name: "does not filter signed thinking blocks when thinking adaptive", + input: `{"thinking":{"type":"adaptive"},"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"ok","signature":"sig_real_123"},{"type":"text","text":"B"}]}]}`, + shouldFilter: false, + }, + { + name: "filters unsigned thinking blocks when thinking adaptive", + input: `{"thinking":{"type":"adaptive"},"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"internal","signature":""},{"type":"text","text":"B"}]}]}`, + shouldFilter: true, + }, { name: "handles no thinking blocks", input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`, @@ -319,3 +438,341 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) { require.Contains(t, content0["text"], "tool_use") require.Contains(t, content1["text"], "tool_result") } + +// ============ Group 7: ParseGatewayRequest 补充单元测试 ============ + +// Task 7.1 — 类型校验边界测试 +func TestParseGatewayRequest_TypeValidation(t *testing.T) { + tests := []struct { + name string + body string + wantErr bool + errSubstr string // 期望的错误信息子串(为空则不检查) + }{ + { + name: "model 为 int", + body: `{"model":123}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + { + name: "model 为 array", + body: `{"model":[]}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + { + name: "model 为 bool", + body: `{"model":true}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + { + name: "model 为 null — gjson Null 类型触发类型校验错误", + body: `{"model":null}`, + wantErr: true, // gjson: Exists()=true, Type=Null != String → 返回错误 + errSubstr: "invalid model field type", + }, + { + name: "stream 为 string", + body: `{"stream":"true"}`, + wantErr: true, + errSubstr: "invalid stream field type", + }, + { + name: "stream 为 int", + body: `{"stream":1}`, + wantErr: true, + errSubstr: "invalid stream field type", + }, + { + name: "stream 为 null — gjson Null 类型触发类型校验错误", + body: `{"stream":null}`, + wantErr: true, // gjson: Exists()=true, Type=Null != True && != False → 返回错误 + errSubstr: "invalid stream field type", + }, + { + name: "model 为 object", + body: `{"model":{}}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseGatewayRequest([]byte(tt.body), "") + if tt.wantErr { + require.Error(t, err) + if tt.errSubstr != "" { + require.Contains(t, err.Error(), tt.errSubstr) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// Task 7.2 — 可选字段缺失测试 +func TestParseGatewayRequest_OptionalFieldsMissing(t *testing.T) { + tests := []struct { + name string + body string + wantModel string + wantStream bool + wantMetadataUID string + wantHasSystem bool + wantThinking bool + wantMaxTokens int + wantMessagesNil bool + wantMessagesLen int + }{ + { + name: "完全空 JSON — 所有字段零值", + body: `{}`, + wantModel: "", + wantStream: false, + wantMetadataUID: "", + wantHasSystem: false, + wantThinking: false, + wantMaxTokens: 0, + wantMessagesNil: true, + }, + { + name: "metadata 无 user_id", + body: `{"model":"test"}`, + wantModel: "test", + wantMetadataUID: "", + wantHasSystem: false, + wantThinking: false, + }, + { + name: "thinking 非 enabled(type=disabled)", + body: `{"model":"test","thinking":{"type":"disabled"}}`, + wantModel: "test", + wantThinking: false, + }, + { + name: "thinking 字段缺失", + body: `{"model":"test"}`, + wantModel: "test", + wantThinking: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parsed, err := ParseGatewayRequest([]byte(tt.body), "") + require.NoError(t, err) + + require.Equal(t, tt.wantModel, parsed.Model) + require.Equal(t, tt.wantStream, parsed.Stream) + require.Equal(t, tt.wantMetadataUID, parsed.MetadataUserID) + require.Equal(t, tt.wantHasSystem, parsed.HasSystem) + require.Equal(t, tt.wantThinking, parsed.ThinkingEnabled) + require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens) + + if tt.wantMessagesNil { + require.Nil(t, parsed.Messages) + } + if tt.wantMessagesLen > 0 { + require.Len(t, parsed.Messages, tt.wantMessagesLen) + } + }) + } +} + +// Task 7.3 — Gemini 协议分支测试 +// 已有测试覆盖: +// - TestParseGatewayRequest_GeminiSystemInstruction: 正常 systemInstruction+contents +// - TestParseGatewayRequest_GeminiNoContents: 缺失 contents +// - TestParseGatewayRequest_GeminiContents: 正常 contents(无 systemInstruction) +// 因此跳过。 + +// Task 7.4 — max_tokens 边界测试 +func TestParseGatewayRequest_MaxTokensBoundary(t *testing.T) { + tests := []struct { + name string + body string + wantMaxTokens int + wantErr bool + }{ + { + name: "正常整数", + body: `{"max_tokens":1024}`, + wantMaxTokens: 1024, + }, + { + name: "浮点数(非整数)被忽略", + body: `{"max_tokens":10.5}`, + wantMaxTokens: 0, + }, + { + name: "负整数可以通过", + body: `{"max_tokens":-1}`, + wantMaxTokens: -1, + }, + { + name: "超大值不 panic", + body: `{"max_tokens":9999999999999999}`, + wantMaxTokens: 10000000000000000, // float64 精度导致 9999999999999999 → 1e16 + }, + { + name: "null 值被忽略", + body: `{"max_tokens":null}`, + wantMaxTokens: 0, // gjson Type=Null != Number → 条件不满足,跳过 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parsed, err := ParseGatewayRequest([]byte(tt.body), "") + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens) + }) + } +} + +// ============ Task 7.5: Benchmark 测试 ============ + +// parseGatewayRequestOld 是基于完整 json.Unmarshal 的旧实现,用于 benchmark 对比基线。 +// 核心路径:先 Unmarshal 到 map[string]any,再逐字段提取。 +func parseGatewayRequestOld(body []byte, protocol string) (*ParsedRequest, error) { + parsed := &ParsedRequest{ + Body: body, + } + + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + // model + if raw, ok := req["model"]; ok { + s, ok := raw.(string) + if !ok { + return nil, fmt.Errorf("invalid model field type") + } + parsed.Model = s + } + + // stream + if raw, ok := req["stream"]; ok { + b, ok := raw.(bool) + if !ok { + return nil, fmt.Errorf("invalid stream field type") + } + parsed.Stream = b + } + + // metadata.user_id + if meta, ok := req["metadata"].(map[string]any); ok { + if uid, ok := meta["user_id"].(string); ok { + parsed.MetadataUserID = uid + } + } + + // thinking.type + if thinking, ok := req["thinking"].(map[string]any); ok { + if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" { + parsed.ThinkingEnabled = true + } + } + + // max_tokens + if raw, ok := req["max_tokens"]; ok { + if n, ok := parseIntegralNumber(raw); ok { + parsed.MaxTokens = n + } + } + + // system / messages(按协议分支) + switch protocol { + case domain.PlatformGemini: + if sysInst, ok := req["systemInstruction"].(map[string]any); ok { + if parts, ok := sysInst["parts"].([]any); ok { + parsed.System = parts + } + } + if contents, ok := req["contents"].([]any); ok { + parsed.Messages = contents + } + default: + if system, ok := req["system"]; ok { + parsed.HasSystem = true + parsed.System = system + } + if messages, ok := req["messages"].([]any); ok { + parsed.Messages = messages + } + } + + return parsed, nil +} + +// buildSmallJSON 构建 ~500B 的小型测试 JSON +func buildSmallJSON() []byte { + return []byte(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":4096,"metadata":{"user_id":"user-abc123"},"thinking":{"type":"enabled","budget_tokens":2048},"system":"You are a helpful assistant.","messages":[{"role":"user","content":"What is the meaning of life?"},{"role":"assistant","content":"The meaning of life is a philosophical question."},{"role":"user","content":"Can you elaborate?"}]}`) +} + +// buildLargeJSON 构建 ~50KB 的大型测试 JSON(大量 messages) +func buildLargeJSON() []byte { + var b strings.Builder + b.WriteString(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":8192,"metadata":{"user_id":"user-xyz789"},"system":[{"type":"text","text":"You are a detailed assistant.","cache_control":{"type":"ephemeral"}}],"messages":[`) + + msgCount := 200 + for i := 0; i < msgCount; i++ { + if i > 0 { + b.WriteByte(',') + } + if i%2 == 0 { + b.WriteString(fmt.Sprintf(`{"role":"user","content":"This is user message number %d with some extra padding text to make the message reasonably long for benchmarking purposes. Lorem ipsum dolor sit amet."}`, i)) + } else { + b.WriteString(fmt.Sprintf(`{"role":"assistant","content":[{"type":"text","text":"This is assistant response number %d. I will provide a detailed answer with multiple sentences to simulate real conversation content for benchmark testing."}]}`, i)) + } + } + + b.WriteString(`]}`) + return []byte(b.String()) +} + +func BenchmarkParseGatewayRequest_Old_Small(b *testing.B) { + data := buildSmallJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = parseGatewayRequestOld(data, "") + } +} + +func BenchmarkParseGatewayRequest_New_Small(b *testing.B) { + data := buildSmallJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ParseGatewayRequest(data, "") + } +} + +func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) { + data := buildLargeJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = parseGatewayRequestOld(data, "") + } +} + +func BenchmarkParseGatewayRequest_New_Large(b *testing.B) { + data := buildLargeJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ParseGatewayRequest(data, "") + } +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 32646b11..48c69881 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5,18 +5,17 @@ import ( "bytes" "context" "crypto/sha256" - "encoding/hex" "encoding/json" "errors" "fmt" "io" - "log" "log/slog" mathrand "math/rand" "net/http" "os" "regexp" "sort" + "strconv" "strings" "sync/atomic" "time" @@ -24,11 +23,16 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "github.com/cespare/xxhash/v2" "github.com/google/uuid" + gocache "github.com/patrickmn/go-cache" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "golang.org/x/sync/singleflight" "github.com/gin-gonic/gin" ) @@ -43,6 +47,9 @@ const ( // separator between system blocks, we add "\n\n" at concatenation time. claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 + + defaultUserGroupRateCacheTTL = 30 * time.Second + defaultModelsListCacheTTL = 15 * time.Second ) const ( @@ -61,6 +68,53 @@ type accountWithLoad struct { var ForceCacheBillingContextKey = forceCacheBillingKeyType{} +var ( + windowCostPrefetchCacheHitTotal atomic.Int64 + windowCostPrefetchCacheMissTotal atomic.Int64 + windowCostPrefetchBatchSQLTotal atomic.Int64 + windowCostPrefetchFallbackTotal atomic.Int64 + windowCostPrefetchErrorTotal atomic.Int64 + + userGroupRateCacheHitTotal atomic.Int64 + userGroupRateCacheMissTotal atomic.Int64 + userGroupRateCacheLoadTotal atomic.Int64 + userGroupRateCacheSFSharedTotal atomic.Int64 + userGroupRateCacheFallbackTotal atomic.Int64 + + modelsListCacheHitTotal atomic.Int64 + modelsListCacheMissTotal atomic.Int64 + modelsListCacheStoreTotal atomic.Int64 +) + +func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) { + return windowCostPrefetchCacheHitTotal.Load(), + windowCostPrefetchCacheMissTotal.Load(), + windowCostPrefetchBatchSQLTotal.Load(), + windowCostPrefetchFallbackTotal.Load(), + windowCostPrefetchErrorTotal.Load() +} + +func GatewayUserGroupRateCacheStats() (cacheHit, cacheMiss, load, singleflightShared, fallback int64) { + return userGroupRateCacheHitTotal.Load(), + userGroupRateCacheMissTotal.Load(), + userGroupRateCacheLoadTotal.Load(), + userGroupRateCacheSFSharedTotal.Load(), + userGroupRateCacheFallbackTotal.Load() +} + +func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { + return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() +} + +func cloneStringSlice(src []string) []string { + if len(src) == 0 { + return nil + } + dst := make([]string, len(src)) + copy(dst, src) + return dst +} + // IsForceCacheBilling 检查是否启用强制缓存计费 func IsForceCacheBilling(ctx context.Context) bool { v, _ := ctx.Value(ForceCacheBillingContextKey).(bool) @@ -73,13 +127,26 @@ func WithForceCacheBilling(ctx context.Context) context.Context { } func (s *GatewayService) debugModelRoutingEnabled() bool { - v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) - return v == "1" || v == "true" || v == "yes" || v == "on" + if s == nil { + return false + } + return s.debugModelRouting.Load() } func (s *GatewayService) debugClaudeMimicEnabled() bool { - v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) - return v == "1" || v == "true" || v == "yes" || v == "on" + if s == nil { + return false + } + return s.debugClaudeMimic.Load() +} + +func parseDebugEnvBool(raw string) bool { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1", "true", "yes", "on": + return true + default: + return false + } } func shortSessionHash(sessionHash string) string { @@ -212,7 +279,7 @@ func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, token if line == "" { return } - log.Printf("[ClaudeMimicDebug] %s", line) + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebug] %s", line) } func isClaudeCodeCredentialScopeError(msg string) bool { @@ -242,12 +309,15 @@ var ( } ) +// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表 +// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除 +var systemBlockFilterPrefixes = []string{ + "x-anthropic-billing-header", +} + // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") -// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内 -var ErrModelScopeNotSupported = errors.New("model scope not supported by this group") - // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -273,13 +343,6 @@ var allowedHeaders = map[string]bool{ // GatewayCache 定义网关服务的缓存操作接口。 // 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。 // -// ModelLoadInfo 模型负载信息(用于 Antigravity 调度) -// Model load info for Antigravity scheduling -type ModelLoadInfo struct { - CallCount int64 // 当前分钟调用次数 / Call count in current minute - LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled) -} - // GatewayCache defines cache operations for gateway service. // Provides sticky session storage, retrieval, refresh and deletion capabilities. type GatewayCache interface { @@ -295,24 +358,6 @@ type GatewayCache interface { // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // Delete sticky session binding, used to proactively clean up when account becomes unavailable DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error - - // IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用) - // Increment model call count and update last scheduling time (Antigravity only) - // 返回更新后的调用次数 - IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) - - // GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用) - // Batch get model load info for accounts (Antigravity only) - GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) - - // FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配) - // Find Gemini session using MGET reverse order matching - // 返回最长匹配的会话信息(uuid, accountID) - FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) - - // SaveGeminiSession 保存 Gemini 会话 - // Save Gemini session binding - SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error } // derefGroupID safely dereferences *int64 to int64, returning 0 if nil @@ -323,21 +368,48 @@ func derefGroupID(groupID *int64) int64 { return *groupID } -// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。 -// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。 -// 低于此阈值时保持粘性会话,等待短暂限流结束。 -const stickySessionRateLimitThreshold = 10 * time.Second +func resolveUserGroupRateCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.UserGroupRateCacheTTLSeconds <= 0 { + return defaultUserGroupRateCacheTTL + } + return time.Duration(cfg.Gateway.UserGroupRateCacheTTLSeconds) * time.Second +} + +func resolveModelsListCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.ModelsListCacheTTLSeconds <= 0 { + return defaultModelsListCacheTTL + } + return time.Duration(cfg.Gateway.ModelsListCacheTTLSeconds) * time.Second +} + +func modelsListCacheKey(groupID *int64, platform string) string { + return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform)) +} + +func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { + return PrefetchedStickyGroupIDFromContext(ctx) +} + +func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 { + prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx) + if !ok || prefetchedGroupID != derefGroupID(groupID) { + return 0 + } + if accountID, ok := PrefetchedStickyAccountIDFromContext(ctx); ok && accountID > 0 { + return accountID + } + return 0 +} // shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 // 当账号状态为错误、禁用、不可调度、处于临时不可调度期间, -// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。 +// 或请求的模型处于限流状态时,返回 true。 // 这确保后续请求不会继续使用不可用的账号。 // // shouldClearStickySession checks if an account is in an unschedulable state // and the sticky session binding should be cleared. // Returns true when account status is error/disabled, schedulable is false, -// within temporary unschedulable period, or model rate limit remaining time -// exceeds stickySessionRateLimitThreshold. +// within temporary unschedulable period, or the requested model is rate-limited. // This ensures subsequent requests won't continue using unavailable accounts. func shouldClearStickySession(account *Account, requestedModel string) bool { if account == nil { @@ -349,8 +421,8 @@ func shouldClearStickySession(account *Account, requestedModel string) bool { if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { return true } - // 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话 - if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold { + // 检查模型限流和 scope 限流,有限流即清除粘性会话 + if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 { return true } return false @@ -376,6 +448,8 @@ type ClaudeUsage struct { OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"` + CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象) + CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象) } // ForwardResult 转发结果 @@ -388,42 +462,72 @@ type ForwardResult struct { FirstTokenMs *int // 首字时间(流式请求) ClientDisconnect bool // 客户端是否在流式传输过程中断开 - // 图片生成计费字段(仅 gemini-3-pro-image 使用) + // 图片生成计费字段(图片生成模型使用) ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" + + // Sora 媒体字段 + MediaType string // image / video / prompt + MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. type UpstreamFailoverError struct { - StatusCode int - ResponseBody []byte // 上游响应体,用于错误透传规则匹配 - ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true + StatusCode int + ResponseBody []byte // 上游响应体,用于错误透传规则匹配 + ResponseHeaders http.Header // 上游响应头,用于透传 cf-ray/cf-mitigated/content-type 等诊断信息 + ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true + RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 } func (e *UpstreamFailoverError) Error() string { return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode) } +// TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。 +// 由 handler 层在同账号重试全部用尽、切换账号时调用。 +func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) { + if failoverErr == nil || !failoverErr.RetryableOnSameAccount { + return + } + // 根据状态码选择封禁策略 + switch failoverErr.StatusCode { + case http.StatusBadRequest: + tempUnscheduleGoogleConfigError(ctx, s.accountRepo, accountID, "[handler]") + case http.StatusBadGateway: + tempUnscheduleEmptyResponse(ctx, s.accountRepo, accountID, "[handler]") + } +} + // GatewayService handles API gateway operations type GatewayService struct { - accountRepo AccountRepository - groupRepo GroupRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - userGroupRateRepo UserGroupRateRepository - cache GatewayCache - cfg *config.Config - schedulerSnapshot *SchedulerSnapshotService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - identityService *IdentityService - httpUpstream HTTPUpstream - deferredService *DeferredService - concurrencyService *ConcurrencyService - claudeTokenProvider *ClaudeTokenProvider - sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + accountRepo AccountRepository + groupRepo GroupRepository + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache GatewayCache + digestStore *DigestSessionStore + cfg *config.Config + schedulerSnapshot *SchedulerSnapshotService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + identityService *IdentityService + httpUpstream HTTPUpstream + deferredService *DeferredService + concurrencyService *ConcurrencyService + claudeTokenProvider *ClaudeTokenProvider + sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) + userGroupRateCache *gocache.Cache + userGroupRateSF singleflight.Group + modelsListCache *gocache.Cache + modelsListCacheTTL time.Duration + responseHeaderFilter *responseheaders.CompiledHeaderFilter + debugModelRouting atomic.Bool + debugClaudeMimic atomic.Bool } // NewGatewayService creates a new GatewayService @@ -446,27 +550,41 @@ func NewGatewayService( deferredService *DeferredService, claudeTokenProvider *ClaudeTokenProvider, sessionLimitCache SessionLimitCache, + rpmCache RPMCache, + digestStore *DigestSessionStore, ) *GatewayService { - return &GatewayService{ - accountRepo: accountRepo, - groupRepo: groupRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - userGroupRateRepo: userGroupRateRepo, - cache: cache, - cfg: cfg, - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - identityService: identityService, - httpUpstream: httpUpstream, - deferredService: deferredService, - claudeTokenProvider: claudeTokenProvider, - sessionLimitCache: sessionLimitCache, + userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) + modelsListTTL := resolveModelsListCacheTTL(cfg) + + svc := &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, + cache: cache, + digestStore: digestStore, + cfg: cfg, + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + identityService: identityService, + httpUpstream: httpUpstream, + deferredService: deferredService, + claudeTokenProvider: claudeTokenProvider, + sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, + userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), + modelsListCache: gocache.New(modelsListTTL, time.Minute), + modelsListCacheTTL: modelsListTTL, + responseHeaderFilter: compileResponseHeaderFilter(cfg), } + svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) + svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) + return svc } // GenerateSessionHash 从预解析请求计算粘性会话 hash @@ -488,23 +606,45 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { return s.hashContent(cacheableContent) } - // 3. Fallback: 使用 system 内容 + // 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串 + var combined strings.Builder + // 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash + if parsed.SessionContext != nil { + _, _ = combined.WriteString(parsed.SessionContext.ClientIP) + _, _ = combined.WriteString(":") + _, _ = combined.WriteString(parsed.SessionContext.UserAgent) + _, _ = combined.WriteString(":") + _, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10)) + _, _ = combined.WriteString("|") + } if parsed.System != nil { systemText := s.extractTextFromSystem(parsed.System) if systemText != "" { - return s.hashContent(systemText) + _, _ = combined.WriteString(systemText) } } - - // 4. 最后 fallback: 使用第一条消息 - if len(parsed.Messages) > 0 { - if firstMsg, ok := parsed.Messages[0].(map[string]any); ok { - msgText := s.extractTextFromContent(firstMsg["content"]) - if msgText != "" { - return s.hashContent(msgText) + for _, msg := range parsed.Messages { + if m, ok := msg.(map[string]any); ok { + if content, exists := m["content"]; exists { + // Anthropic: messages[].content + if msgText := s.extractTextFromContent(content); msgText != "" { + _, _ = combined.WriteString(msgText) + } + } else if parts, ok := m["parts"].([]any); ok { + // Gemini: contents[].parts[].text + for _, part := range parts { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok { + _, _ = combined.WriteString(text) + } + } + } } } } + if combined.Len() > 0 { + return s.hashContent(combined.String()) + } return "" } @@ -532,19 +672,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID // FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配) // 返回最长匹配的会话信息(uuid, accountID) -func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" || s.cache == nil { - return "", 0, false +func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false } - return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain) + return s.digestStore.Find(groupID, prefixHash, digestChain) } -// SaveGeminiSession 保存 Gemini 会话 -func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - if digestChain == "" || s.cache == nil { +// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。 +func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { return nil } - return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil +} + +// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配) +func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false + } + return s.digestStore.Find(groupID, prefixHash, digestChain) +} + +// SaveAnthropicSession 保存 Anthropic 会话 +func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { + return nil + } + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil } func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { @@ -629,8 +787,8 @@ func (s *GatewayService) extractTextFromContent(content any) string { } func (s *GatewayService) hashContent(content string) string { - hash := sha256.Sum256([]byte(content)) - return hex.EncodeToString(hash[:16]) // 32字符 + h := xxhash.Sum64String(content) + return strconv.FormatUint(h, 36) } // replaceModelInBody 替换请求体中的model字段 @@ -897,13 +1055,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro cfg := s.schedulingConfig() - var stickyAccountID int64 - if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { - stickyAccountID = accountID - } - } - // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) if err != nil { @@ -911,12 +1062,21 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) + var stickyAccountID int64 + if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { + stickyAccountID = prefetch + } else if sessionHash != "" && s.cache != nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { + stickyAccountID = accountID + } + } + if s.debugModelRoutingEnabled() && requestedModel != "" { groupPlatform := "" if group != nil { groupPlatform = group.Platform } - log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v", derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil) } @@ -986,14 +1146,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } preferOAuth := platform == PlatformGemini if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" { - log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) - } - - // Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查) - if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { - if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { - return nil, err - } + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) } accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) @@ -1003,6 +1156,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if len(accounts) == 0 { return nil, errors.New("no available accounts") } + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) isExcluded := func(accountID int64) bool { if excludedIDs == nil { @@ -1023,7 +1178,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic { routingAccountIDs = group.GetRoutingAccountIDs(requestedModel) if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d", group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID) if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 { keys := make([]string, 0, len(group.ModelRouting)) @@ -1035,7 +1190,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if len(keys) > maxKeys { keys = keys[:maxKeys] } - log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys) } } } @@ -1052,7 +1207,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro continue } account, ok := accountByID[routingAccountID] - if !ok || !account.IsSchedulable() { + if !ok || !s.isAccountSchedulableForSelection(account) { if !ok { filteredMissing++ } else { @@ -1068,7 +1223,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro filteredModelMapping++ continue } - if !account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { filteredModelScope++ modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID) continue @@ -1078,31 +1233,36 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro filteredWindowCost++ continue } + // RPM 检查(非粘性会话路径) + if !s.isAccountSchedulableForRPM(ctx, account, false) { + continue + } routingCandidates = append(routingCandidates, account) } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)", derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates), filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost) if len(modelScopeSkippedIDs) > 0 { - log.Printf("[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v", derefGroupID(groupID), requestedModel, modelScopeSkippedIDs) } } if len(routingCandidates) > 0 { // 1.5. 在路由账号范围内检查粘性会话 - if sessionHash != "" && s.cache != nil { - stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { + if sessionHash != "" && stickyAccountID > 0 { + if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { - if stickyAccount.IsSchedulable() && + if s.isAccountSchedulableForSelection(stickyAccount) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && - stickyAccount.IsSchedulableForModelWithContext(ctx, requestedModel) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查 + s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && + + s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 @@ -1110,9 +1270,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result.ReleaseFunc() // 释放槽位 // 继续到负载感知选择 } else { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } return &AccountSelectionResult{ Account: stickyAccount, @@ -1190,6 +1349,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return a.account.LastUsedAt.Before(*b.account.LastUsedAt) } }) + shuffleWithinSortGroups(routingAvailable) // 4. 尝试获取槽位 for _, item := range routingAvailable { @@ -1204,7 +1364,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } return &AccountSelectionResult{ Account: item.account, @@ -1221,7 +1381,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro continue // 会话限制已满,尝试下一个 } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } return &AccountSelectionResult{ Account: item.account, @@ -1236,14 +1396,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退 - log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) } } // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ - if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil { - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { + if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { + accountID := stickyAccountID + if accountID > 0 && !isExcluded(accountID) { account, ok := accountByID[accountID] if ok { // 检查账户是否需要清理粘性会话绑定 @@ -1255,8 +1415,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !clearSticky && s.isAccountInGroup(account, groupID) && s.isAccountAllowedForPlatform(account, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && - account.IsSchedulableForModelWithContext(ctx, requestedModel) && - s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查 + s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && + s.isAccountSchedulableForWindowCost(ctx, account, true) && + + s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 @@ -1264,7 +1426,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -1307,7 +1468,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); // re-check schedulability here so recently rate-limited/overloaded accounts // are not selected again before the bucket is rebuilt. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { @@ -1316,13 +1477,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } // 窗口费用检查(非粘性会话路径) if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } + // RPM 检查(非粘性会话路径) + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } candidates = append(candidates, acc) } @@ -1344,10 +1509,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return result, nil } } else { - // Antigravity 平台:获取模型负载信息 - var modelLoadMap map[int64]*ModelLoadInfo - isAntigravity := platform == PlatformAntigravity - var available []accountWithLoad for _, acc := range candidates { loadInfo := loadMap[acc.ID] @@ -1362,109 +1523,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致) - if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 { - modelLoadMap = make(map[int64]*ModelLoadInfo, len(available)) - modelToAccountIDs := make(map[string][]int64) - for _, item := range available { - mappedModel := mapAntigravityModel(item.account, requestedModel) - if mappedModel == "" { - continue - } - modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID) + // 分层过滤选择:优先级 → 负载率 → LRU + for len(available) > 0 { + // 1. 取优先级最小的集合 + candidates := filterByMinPriority(available) + // 2. 取负载率最低的集合 + candidates = filterByMinLoadRate(candidates) + // 3. LRU 选择最久未用的账号 + selected := selectByLRU(candidates, preferOAuth) + if selected == nil { + break } - for model, ids := range modelToAccountIDs { - batch, err := s.cache.GetModelLoadBatch(ctx, ids, model) - if err != nil { - continue - } - for id, info := range batch { - modelLoadMap[id] = info - } - } - if len(modelLoadMap) == 0 { - modelLoadMap = nil - } - } - // Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值) - // 其他平台:分层过滤选择:优先级 → 负载率 → LRU - if isAntigravity { - for len(available) > 0 { - // 1. 取优先级最小的集合(硬过滤) - candidates := filterByMinPriority(available) - // 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值) - selected := selectByCallCount(candidates, modelLoadMap, preferOAuth) - if selected == nil { - break - } - - result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) - if err == nil && result.Acquired { - // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { - result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 - } else { - if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: selected.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) + if err == nil && result.Acquired { + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { + result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 + } else { + if sessionHash != "" && s.cache != nil { + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } + return &AccountSelectionResult{ + Account: selected.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } - - // 移除已尝试的账号,重新选择 - selectedID := selected.account.ID - newAvailable := make([]accountWithLoad, 0, len(available)-1) - for _, acc := range available { - if acc.account.ID != selectedID { - newAvailable = append(newAvailable, acc) - } - } - available = newAvailable } - } else { - for len(available) > 0 { - // 1. 取优先级最小的集合 - candidates := filterByMinPriority(available) - // 2. 取负载率最低的集合 - candidates = filterByMinLoadRate(candidates) - // 3. LRU 选择最久未用的账号 - selected := selectByLRU(candidates, preferOAuth) - if selected == nil { - break - } - result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) - if err == nil && result.Acquired { - // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { - result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 - } else { - if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: selected.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } + // 移除已尝试的账号,重新进行分层过滤 + selectedID := selected.account.ID + newAvailable := make([]accountWithLoad, 0, len(available)-1) + for _, acc := range available { + if acc.account.ID != selectedID { + newAvailable = append(newAvailable, acc) } - - // 移除已尝试的账号,重新进行分层过滤 - selectedID := selected.account.ID - newAvailable := make([]accountWithLoad, 0, len(available)-1) - for _, acc := range available { - if acc.account.ID != selectedID { - newAvailable = append(newAvailable, acc) - } - } - available = newAvailable } + available = newAvailable } } @@ -1567,20 +1663,20 @@ func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupI group, err := s.resolveGroupByID(ctx, *groupID) if err != nil || group == nil { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err) } return nil } // Preserve existing behavior: model routing only applies to anthropic groups. if group.Platform != PlatformAnthropic { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel) } return nil } ids := group.GetRoutingAccountIDs(requestedModel) if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v", group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids) } return ids @@ -1656,6 +1752,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr } func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { + if platform == PlatformSora { + return s.listSoraSchedulableAccounts(ctx, groupID) + } if s.schedulerSnapshot != nil { accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err == nil { @@ -1750,6 +1849,64 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i return accounts, useMixed, nil } +func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) { + const useMixed = false + + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } else if groupID != nil { + accounts, err = s.accountRepo.ListByGroup(ctx, *groupID) + } else { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } + if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "error", err) + return nil, useMixed, err + } + + filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform != PlatformSora { + continue + } + if !s.isSoraAccountSchedulable(&acc) { + continue + } + filtered = append(filtered, acc) + } + slog.Debug("account_scheduling_list_sora", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "raw_count", len(accounts), + "filtered_count", len(filtered)) + for _, acc := range filtered { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + return filtered, useMixed, nil +} + +// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 +// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, +// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 +func (s *GatewayService) IsSingleAntigravityAccountGroup(ctx context.Context, groupID *int64) bool { + accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAntigravity, true) + if err != nil { + return false + } + return len(accounts) == 1 +} + func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { if account == nil { return false @@ -1763,6 +1920,49 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform return account.Platform == platform } +func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool { + return s.soraUnschedulableReason(account) == "" +} + +func (s *GatewayService) soraUnschedulableReason(account *Account) string { + if account == nil { + return "account_nil" + } + if account.Status != StatusActive { + return fmt.Sprintf("status=%s", account.Status) + } + if !account.Schedulable { + return "schedulable=false" + } + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339)) + } + return "" +} + +func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { + if account == nil { + return false + } + if account.Platform == PlatformSora { + return s.isSoraAccountSchedulable(account) + } + return account.IsSchedulable() +} + +func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool { + if account == nil { + return false + } + if account.Platform == PlatformSora { + if !s.isSoraAccountSchedulable(account) { + return false + } + return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0 + } + return account.IsSchedulableForModelWithContext(ctx, requestedModel) +} + // isAccountInGroup checks if the account belongs to the specified group. // Returns true if groupID is nil (no group restriction) or account belongs to the group. func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { @@ -1787,6 +1987,129 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } +type usageLogWindowStatsBatchProvider interface { + GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) +} + +type windowCostPrefetchContextKeyType struct{} + +var windowCostPrefetchContextKey = windowCostPrefetchContextKeyType{} + +func windowCostFromPrefetchContext(ctx context.Context, accountID int64) (float64, bool) { + if ctx == nil || accountID <= 0 { + return 0, false + } + m, ok := ctx.Value(windowCostPrefetchContextKey).(map[int64]float64) + if !ok || len(m) == 0 { + return 0, false + } + v, exists := m[accountID] + return v, exists +} + +func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []Account) context.Context { + if ctx == nil || len(accounts) == 0 || s.sessionLimitCache == nil || s.usageLogRepo == nil { + return ctx + } + + accountByID := make(map[int64]*Account) + accountIDs := make([]int64, 0, len(accounts)) + for i := range accounts { + account := &accounts[i] + if account == nil || !account.IsAnthropicOAuthOrSetupToken() { + continue + } + if account.GetWindowCostLimit() <= 0 { + continue + } + accountByID[account.ID] = account + accountIDs = append(accountIDs, account.ID) + } + if len(accountIDs) == 0 { + return ctx + } + + costs := make(map[int64]float64, len(accountIDs)) + cacheValues, err := s.sessionLimitCache.GetWindowCostBatch(ctx, accountIDs) + if err == nil { + for accountID, cost := range cacheValues { + costs[accountID] = cost + } + windowCostPrefetchCacheHitTotal.Add(int64(len(cacheValues))) + } else { + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch cache read failed: %v", err) + } + cacheMissCount := len(accountIDs) - len(costs) + if cacheMissCount < 0 { + cacheMissCount = 0 + } + windowCostPrefetchCacheMissTotal.Add(int64(cacheMissCount)) + + missingByStart := make(map[int64][]int64) + startTimes := make(map[int64]time.Time) + for _, accountID := range accountIDs { + if _, ok := costs[accountID]; ok { + continue + } + account := accountByID[accountID] + if account == nil { + continue + } + startTime := account.GetCurrentWindowStartTime() + startKey := startTime.Unix() + missingByStart[startKey] = append(missingByStart[startKey], accountID) + startTimes[startKey] = startTime + } + if len(missingByStart) == 0 { + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) + } + + batchReader, hasBatch := s.usageLogRepo.(usageLogWindowStatsBatchProvider) + for startKey, ids := range missingByStart { + startTime := startTimes[startKey] + + if hasBatch { + windowCostPrefetchBatchSQLTotal.Add(1) + queryStart := time.Now() + statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, ids, startTime) + if err == nil { + slog.Debug("window_cost_batch_query_ok", + "accounts", len(ids), + "window_start", startTime.Format(time.RFC3339), + "duration_ms", time.Since(queryStart).Milliseconds()) + for _, accountID := range ids { + stats := statsByAccount[accountID] + cost := 0.0 + if stats != nil { + cost = stats.StandardCost + } + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) + } + continue + } + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch db query failed: start=%s err=%v", startTime.Format(time.RFC3339), err) + } + + // 回退路径:缺少批量仓储能力或批量查询失败时,按账号单查(失败开放)。 + windowCostPrefetchFallbackTotal.Add(int64(len(ids))) + for _, accountID := range ids { + stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime) + if err != nil { + windowCostPrefetchErrorTotal.Add(1) + continue + } + cost := stats.StandardCost + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) + } + } + + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) +} + // isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 // 仅适用于 Anthropic OAuth/SetupToken 账号 // 返回 true 表示可调度,false 表示不可调度 @@ -1803,6 +2126,10 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, // 尝试从缓存获取窗口费用 var currentCost float64 + if cost, ok := windowCostFromPrefetchContext(ctx, account.ID); ok { + currentCost = cost + goto checkSchedulability + } if s.sessionLimitCache != nil { if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit { currentCost = cost @@ -1844,6 +2171,88 @@ checkSchedulability: return true } +// rpmPrefetchContextKey is the context key for prefetched RPM counts. +type rpmPrefetchContextKeyType struct{} + +var rpmPrefetchContextKey = rpmPrefetchContextKeyType{} + +func rpmFromPrefetchContext(ctx context.Context, accountID int64) (int, bool) { + if v, ok := ctx.Value(rpmPrefetchContextKey).(map[int64]int); ok { + count, found := v[accountID] + return count, found + } + return 0, false +} + +// withRPMPrefetch 批量预取所有候选账号的 RPM 计数 +func (s *GatewayService) withRPMPrefetch(ctx context.Context, accounts []Account) context.Context { + if s.rpmCache == nil { + return ctx + } + + var ids []int64 + for i := range accounts { + if accounts[i].IsAnthropicOAuthOrSetupToken() && accounts[i].GetBaseRPM() > 0 { + ids = append(ids, accounts[i].ID) + } + } + if len(ids) == 0 { + return ctx + } + + counts, err := s.rpmCache.GetRPMBatch(ctx, ids) + if err != nil { + return ctx // 失败开放 + } + return context.WithValue(ctx, rpmPrefetchContextKey, counts) +} + +// isAccountSchedulableForRPM 检查账号是否可根据 RPM 进行调度 +// 仅适用于 Anthropic OAuth/SetupToken 账号 +func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account *Account, isSticky bool) bool { + if !account.IsAnthropicOAuthOrSetupToken() { + return true + } + baseRPM := account.GetBaseRPM() + if baseRPM <= 0 { + return true + } + + // 尝试从预取缓存获取 + var currentRPM int + if count, ok := rpmFromPrefetchContext(ctx, account.ID); ok { + currentRPM = count + } else if s.rpmCache != nil { + if count, err := s.rpmCache.GetRPM(ctx, account.ID); err == nil { + currentRPM = count + } + // 失败开放:GetRPM 错误时允许调度 + } + + schedulability := account.CheckRPMSchedulability(currentRPM) + switch schedulability { + case WindowCostSchedulable: + return true + case WindowCostStickyOnly: + return isSticky + case WindowCostNotSchedulable: + return false + } + return true +} + +// IncrementAccountRPM increments the RPM counter for the given account. +// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口, +// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit +// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。 +func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error { + if s.rpmCache == nil { + return nil + } + _, err := s.rpmCache.IncrementRPM(ctx, accountID) + return err +} + // checkAndRegisterSession 检查并注册会话,用于会话数量限制 // 仅适用于 Anthropic OAuth/SetupToken 账号 // sessionID: 会话标识符(使用粘性会话的 hash) @@ -2000,87 +2409,104 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { return a.LastUsedAt.Before(*b.LastUsedAt) } }) + shuffleWithinPriorityAndLastUsed(accounts, preferOAuth) } -// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用) -// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调 -// 如果有多个账号具有相同的最小调用次数,则随机选择一个 -func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad { - if len(accounts) == 0 { - return nil +// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。 +// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。 +func shuffleWithinSortGroups(accounts []accountWithLoad) { + if len(accounts) <= 1 { + return } - if len(accounts) == 1 { - return &accounts[0] - } - - // 如果没有负载信息,回退到 LRU - if modelLoadMap == nil { - return selectByLRU(accounts, preferOAuth) - } - - // 1. 计算平均调用次数(用于新账号冷启动) - var totalCallCount int64 - var countWithCalls int - for _, acc := range accounts { - if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 { - totalCallCount += info.CallCount - countWithCalls++ + i := 0 + for i < len(accounts) { + j := i + 1 + for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) { + j++ } - } - - var avgCallCount int64 - if countWithCalls > 0 { - avgCallCount = totalCallCount / int64(countWithCalls) - } - - // 2. 获取每个账号的有效调用次数 - getEffectiveCallCount := func(acc accountWithLoad) int64 { - if acc.account == nil { - return 0 + if j-i > 1 { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) } - info := modelLoadMap[acc.account.ID] - if info == nil || info.CallCount == 0 { - return avgCallCount // 新账号使用平均值 + i = j + } +} + +// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组 +func sameAccountWithLoadGroup(a, b accountWithLoad) bool { + if a.account.Priority != b.account.Priority { + return false + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return false + } + return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt) +} + +// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 +// +// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。 +// 因此这里采用"组内分区 + 分区内 shuffle"的方式: +// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前; +// - 再分别在各段内随机打散,避免热点。 +func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { + if len(accounts) <= 1 { + return + } + i := 0 + for i < len(accounts) { + j := i + 1 + for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) { + j++ } - return info.CallCount - } - - // 3. 找到最小调用次数 - minCount := getEffectiveCallCount(accounts[0]) - for _, acc := range accounts[1:] { - if c := getEffectiveCallCount(acc); c < minCount { - minCount = c - } - } - - // 4. 收集所有具有最小调用次数的账号 - var candidateIdxs []int - for i, acc := range accounts { - if getEffectiveCallCount(acc) == minCount { - candidateIdxs = append(candidateIdxs, i) - } - } - - // 5. 如果只有一个候选,直接返回 - if len(candidateIdxs) == 1 { - return &accounts[candidateIdxs[0]] - } - - // 6. preferOAuth 处理 - if preferOAuth { - var oauthIdxs []int - for _, idx := range candidateIdxs { - if accounts[idx].account.Type == AccountTypeOAuth { - oauthIdxs = append(oauthIdxs, idx) + if j-i > 1 { + if preferOAuth { + oauth := make([]*Account, 0, j-i) + others := make([]*Account, 0, j-i) + for _, acc := range accounts[i:j] { + if acc.Type == AccountTypeOAuth { + oauth = append(oauth, acc) + } else { + others = append(others, acc) + } + } + if len(oauth) > 1 { + mathrand.Shuffle(len(oauth), func(a, b int) { oauth[a], oauth[b] = oauth[b], oauth[a] }) + } + if len(others) > 1 { + mathrand.Shuffle(len(others), func(a, b int) { others[a], others[b] = others[b], others[a] }) + } + copy(accounts[i:], oauth) + copy(accounts[i+len(oauth):], others) + } else { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) } } - if len(oauthIdxs) > 0 { - candidateIdxs = oauthIdxs - } + i = j } +} - // 7. 随机选择 - return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]] +// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt) +func sameAccountGroup(a, b *Account) bool { + if a.Priority != b.Priority { + return false + } + return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt) +} + +// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒) +func sameLastUsedAt(a, b *time.Time) bool { + switch { + case a == nil && b == nil: + return true + case a == nil || b == nil: + return false + default: + return a.Unix() == b.Unix() + } } // sortCandidatesForFallback 根据配置选择排序策略 @@ -2135,13 +2561,6 @@ func shuffleWithinPriority(accounts []*Account) { // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { - // 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内 - if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { - if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { - return nil, err - } - } - preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) @@ -2153,7 +2572,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // so switching model can switch upstream account within the same sticky session. if len(routingAccountIDs) > 0 { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs) } // 1) Sticky session only applies if the bound account is within the routing set. @@ -2168,12 +2587,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } return account, nil } @@ -2194,6 +2610,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } accountsLoaded = true + // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { if id > 0 { @@ -2212,13 +2632,19 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { continue } if selected == nil { @@ -2248,15 +2674,15 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if selected != nil { if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) } return selected, nil } - log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) } // 1. 查询粘性会话 @@ -2271,10 +2697,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { return account, nil } } @@ -2295,6 +2718,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } } + // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + // 3. 按优先级+最久未用选择(考虑模型支持) var selected *Account for i := range accounts { @@ -2304,13 +2731,19 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { continue } if selected == nil { @@ -2338,8 +2771,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } if selected == nil { + stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false) if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) + return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats)) } return nil, errors.New("no available accounts") } @@ -2347,7 +2781,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // 4. 建立粘性绑定 if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } @@ -2366,7 +2800,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g // ============ Model Routing (legacy path): apply before sticky session ============ if len(routingAccountIDs) > 0 { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs) } // 1) Sticky session only applies if the bound account is within the routing set. @@ -2381,13 +2815,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } return account, nil } @@ -2405,6 +2836,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } accountsLoaded = true + // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { if id > 0 { @@ -2423,7 +2858,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 @@ -2433,7 +2868,13 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { continue } if selected == nil { @@ -2463,15 +2904,15 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if selected != nil { if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) } return selected, nil } - log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) } // 1. 查询粘性会话 @@ -2486,11 +2927,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } return account, nil } } @@ -2508,6 +2946,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } } + // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) var selected *Account for i := range accounts { @@ -2517,7 +2959,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 @@ -2527,7 +2969,13 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { continue } if selected == nil { @@ -2555,8 +3003,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } if selected == nil { + stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true) if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) + return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats)) } return nil, errors.New("no available accounts") } @@ -2564,13 +3013,243 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g // 4. 建立粘性绑定 if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } return selected, nil } +type selectionFailureStats struct { + Total int + Eligible int + Excluded int + Unschedulable int + PlatformFiltered int + ModelUnsupported int + ModelRateLimited int + SamplePlatformIDs []int64 + SampleMappingIDs []int64 + SampleRateLimitIDs []string +} + +type selectionFailureDiagnosis struct { + Category string + Detail string +} + +func (s *GatewayService) logDetailedSelectionFailure( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + platform string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := s.collectSelectionFailureStats(ctx, accounts, requestedModel, platform, excludedIDs, allowMixedScheduling) + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed] group_id=%v model=%s platform=%s session=%s total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d sample_platform_filtered=%v sample_model_unsupported=%v sample_model_rate_limited=%v", + derefGroupID(groupID), + requestedModel, + platform, + shortSessionHash(sessionHash), + stats.Total, + stats.Eligible, + stats.Excluded, + stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + stats.SamplePlatformIDs, + stats.SampleMappingIDs, + stats.SampleRateLimitIDs, + ) + if platform == PlatformSora { + s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling) + } + return stats +} + +func (s *GatewayService) collectSelectionFailureStats( + ctx context.Context, + accounts []Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := selectionFailureStats{ + Total: len(accounts), + } + + for i := range accounts { + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, platform, excludedIDs, allowMixedScheduling) + switch diagnosis.Category { + case "excluded": + stats.Excluded++ + case "unschedulable": + stats.Unschedulable++ + case "platform_filtered": + stats.PlatformFiltered++ + stats.SamplePlatformIDs = appendSelectionFailureSampleID(stats.SamplePlatformIDs, acc.ID) + case "model_unsupported": + stats.ModelUnsupported++ + stats.SampleMappingIDs = appendSelectionFailureSampleID(stats.SampleMappingIDs, acc.ID) + case "model_rate_limited": + stats.ModelRateLimited++ + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + stats.SampleRateLimitIDs = appendSelectionFailureRateSample(stats.SampleRateLimitIDs, acc.ID, remaining) + default: + stats.Eligible++ + } + } + + return stats +} + +func (s *GatewayService) diagnoseSelectionFailure( + ctx context.Context, + acc *Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureDiagnosis { + if acc == nil { + return selectionFailureDiagnosis{Category: "unschedulable", Detail: "account_nil"} + } + if _, excluded := excludedIDs[acc.ID]; excluded { + return selectionFailureDiagnosis{Category: "excluded"} + } + if !s.isAccountSchedulableForSelection(acc) { + detail := "generic_unschedulable" + if acc.Platform == PlatformSora { + detail = s.soraUnschedulableReason(acc) + } + return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} + } + if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { + return selectionFailureDiagnosis{ + Category: "platform_filtered", + Detail: fmt.Sprintf("account_platform=%s requested_platform=%s", acc.Platform, strings.TrimSpace(platform)), + } + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + return selectionFailureDiagnosis{ + Category: "model_unsupported", + Detail: fmt.Sprintf("model=%s", requestedModel), + } + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + return selectionFailureDiagnosis{ + Category: "model_rate_limited", + Detail: fmt.Sprintf("remaining=%s", remaining), + } + } + return selectionFailureDiagnosis{Category: "eligible"} +} + +func (s *GatewayService) logSoraSelectionFailureDetails( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) { + const maxLines = 30 + logged := 0 + + for i := range accounts { + if logged >= maxLines { + break + } + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling) + if diagnosis.Category == "eligible" { + continue + } + detail := diagnosis.Detail + if detail == "" { + detail = "-" + } + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + acc.ID, + acc.Platform, + diagnosis.Category, + detail, + ) + logged++ + } + if len(accounts) > maxLines { + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + len(accounts), + logged, + ) + } +} + +func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { + if acc == nil { + return true + } + if allowMixedScheduling { + if acc.Platform == PlatformAntigravity { + return !acc.IsMixedSchedulingEnabled() + } + return acc.Platform != platform + } + if strings.TrimSpace(platform) == "" { + return false + } + return acc.Platform != platform +} + +func appendSelectionFailureSampleID(samples []int64, id int64) []int64 { + const limit = 5 + if len(samples) >= limit { + return samples + } + return append(samples, id) +} + +func appendSelectionFailureRateSample(samples []string, accountID int64, remaining time.Duration) []string { + const limit = 5 + if len(samples) >= limit { + return samples + } + return append(samples, fmt.Sprintf("%d(%s)", accountID, remaining)) +} + +func summarizeSelectionFailureStats(stats selectionFailureStats) string { + return fmt.Sprintf( + "total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d", + stats.Total, + stats.Eligible, + stats.Excluded, + stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + ) +} + // isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context) // 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持 func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool { @@ -2584,7 +3263,7 @@ func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Contex return false } // 应用 thinking 后缀后检查最终模型是否在账号映射中 - if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + if enabled, ok := ThinkingEnabledFromContext(ctx); ok { finalModel := applyThinkingModelSuffix(mapped, enabled) if finalModel == mapped { return true // thinking 后缀未改变模型名,映射已通过 @@ -2604,18 +3283,154 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } return mapAntigravityModel(account, requestedModel) != "" } + if account.Platform == PlatformSora { + return s.isSoraModelSupportedByAccount(account, requestedModel) + } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) } - // Gemini API Key 账户直接透传,由上游判断模型是否支持 - if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey { - return true - } // 其他平台使用账户的模型支持检查 return account.IsModelSupported(requestedModel) } +func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool { + if account == nil { + return false + } + if strings.TrimSpace(requestedModel) == "" { + return true + } + + // 先走原始精确/通配符匹配。 + mapping := account.GetModelMapping() + if len(mapping) == 0 || account.IsModelSupported(requestedModel) { + return true + } + + aliases := buildSoraModelAliases(requestedModel) + if len(aliases) == 0 { + return false + } + + hasSoraSelector := false + for pattern := range mapping { + if !isSoraModelSelector(pattern) { + continue + } + hasSoraSelector = true + if matchPatternAnyAlias(pattern, aliases) { + return true + } + } + + // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*), + // 此时不应误拦截 Sora 模型请求。 + if !hasSoraSelector { + return true + } + + return false +} + +func matchPatternAnyAlias(pattern string, aliases []string) bool { + normalizedPattern := strings.ToLower(strings.TrimSpace(pattern)) + if normalizedPattern == "" { + return false + } + for _, alias := range aliases { + if matchWildcard(normalizedPattern, alias) { + return true + } + } + return false +} + +func isSoraModelSelector(pattern string) bool { + p := strings.ToLower(strings.TrimSpace(pattern)) + if p == "" { + return false + } + + switch { + case strings.HasPrefix(p, "sora"), + strings.HasPrefix(p, "gpt-image"), + strings.HasPrefix(p, "prompt-enhance"), + strings.HasPrefix(p, "sy_"): + return true + } + + return p == "video" || p == "image" +} + +func buildSoraModelAliases(requestedModel string) []string { + modelID := strings.ToLower(strings.TrimSpace(requestedModel)) + if modelID == "" { + return nil + } + + aliases := make([]string, 0, 8) + addAlias := func(value string) { + v := strings.ToLower(strings.TrimSpace(value)) + if v == "" { + return + } + for _, existing := range aliases { + if existing == v { + return + } + } + aliases = append(aliases, v) + } + + addAlias(modelID) + cfg, ok := GetSoraModelConfig(modelID) + if ok { + addAlias(cfg.Model) + switch cfg.Type { + case "video": + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case "image": + addAlias("image") + addAlias("gpt-image") + case "prompt_enhance": + addAlias("prompt-enhance") + } + return aliases + } + + switch { + case strings.HasPrefix(modelID, "sora"): + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case strings.HasPrefix(modelID, "gpt-image"): + addAlias("image") + addAlias("gpt-image") + case strings.HasPrefix(modelID, "prompt-enhance"): + addAlias("prompt-enhance") + default: + return nil + } + + return aliases +} + +func soraVideoFamilyAlias(modelID string) string { + switch { + case strings.HasPrefix(modelID, "sora2pro-hd"): + return "sora2pro-hd" + case strings.HasPrefix(modelID, "sora2pro"): + return "sora2pro" + case strings.HasPrefix(modelID, "sora2"): + return "sora2" + default: + return "" + } +} + // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { @@ -2767,6 +3582,60 @@ func hasClaudeCodePrefix(text string) bool { return false } +// matchesFilterPrefix 检查文本是否匹配任一过滤前缀 +func matchesFilterPrefix(text string) bool { + for _, prefix := range systemBlockFilterPrefixes { + if strings.HasPrefix(text, prefix) { + return true + } + } + return false +} + +// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素 +// 直接从 body 解析 system,不依赖外部传入的 parsed.System(因为前置步骤可能已修改 body 中的 system) +func filterSystemBlocksByPrefix(body []byte) []byte { + sys := gjson.GetBytes(body, "system") + if !sys.Exists() { + return body + } + + switch { + case sys.Type == gjson.String: + if matchesFilterPrefix(sys.Str) { + result, err := sjson.DeleteBytes(body, "system") + if err != nil { + return body + } + return result + } + case sys.IsArray(): + var parsed []any + if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil { + return body + } + filtered := make([]any, 0, len(parsed)) + changed := false + for _, item := range parsed { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) { + changed = true + continue + } + } + filtered = append(filtered, item) + } + if changed { + result, err := sjson.SetBytes(body, "system", filtered) + if err != nil { + return body + } + return result + } + } + return body +} + // injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 // 处理 null、字符串、数组三种格式 func injectClaudeCodePrompt(body []byte, system any) []byte { @@ -2825,7 +3694,7 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { result, err := sjson.SetBytes(body, "system", newSystem) if err != nil { - log.Printf("Warning: failed to inject Claude Code prompt: %v", err) + logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt: %v", err) return body } return result @@ -2981,7 +3850,7 @@ func removeCacheControlFromThinkingBlocks(data map[string]any) { if blockType, _ := m["type"].(string); blockType == "thinking" { if _, has := m["cache_control"]; has { delete(m, "cache_control") - log.Printf("[Warning] Removed illegal cache_control from thinking block in system") + logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in system") } } } @@ -2998,7 +3867,7 @@ func removeCacheControlFromThinkingBlocks(data map[string]any) { if blockType, _ := m["type"].(string); blockType == "thinking" { if _, has := m["cache_control"]; has { delete(m, "cache_control") - log.Printf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx) + logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx) } } } @@ -3016,6 +3885,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, fmt.Errorf("parse request: empty request") } + if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { + return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime) + } + body := parsed.Body reqModel := parsed.Model reqStream := parsed.Stream @@ -3046,6 +3919,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } + // OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据) + // 放在 inject/normalize 之后,确保不会被覆盖 + if account.IsOAuth() { + body = filterSystemBlocksByPrefix(body) + } + // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) @@ -3071,7 +3950,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 替换请求体中的模型名 body = s.replaceModelInBody(body, mappedModel) reqModel = mappedModel - log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) + logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) } // 获取凭证 @@ -3087,16 +3966,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 调试日志:记录即将转发的账号信息 - log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", + logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL) + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 + setOpsUpstreamRequestBody(c, body) // 重试循环 var resp *http.Response retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - // Capture upstream request body for ops retry of this attempt. - c.Set(OpsUpstreamRequestBodyKey, string(body)) upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if err != nil { return nil, err @@ -3167,7 +4046,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) break } - log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) // Conservative two-stage fallback: // 1) Disable thinking + thinking->text (preserve content) @@ -3180,7 +4059,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { if retryResp.StatusCode < 400 { - log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry succeeded (thinking downgraded)", account.ID) resp = retryResp break } @@ -3205,7 +4084,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }) msg2 := extractUpstreamErrorMessage(retryRespBody) if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { - log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if buildErr2 == nil { @@ -3225,9 +4104,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A Kind: "signature_retry_tools_request_error", Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), }) - log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2) + logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2) } else { - log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2) + logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2) } } } @@ -3243,9 +4122,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if retryResp != nil && retryResp.Body != nil { _ = retryResp.Body.Close() } - log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr) + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry failed: %v", account.ID, retryErr) } else { - log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr) + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry build request failed: %v", account.ID, buildErr) } // Retry failed: restore original response body and continue handling. @@ -3291,7 +4170,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", + logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) if err := sleepWithContext(ctx, delay); err != nil { return nil, err @@ -3304,10 +4183,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 不需要重试(成功或不可重试的错误),跳出循环 // DEBUG: 输出响应 headers(用于检测 rate limit 信息) - if account.Platform == PlatformGemini && resp.StatusCode < 400 { - log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID) + if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID) for k, v := range resp.Header { - log.Printf("[DEBUG] %s: %v", k, v) + logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v) } } break @@ -3325,7 +4204,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) // 调试日志:打印重试耗尽后的错误响应 - log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) s.handleRetryExhaustedSideEffects(ctx, resp, account) @@ -3356,7 +4235,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) // 调试日志:打印上游错误响应 - log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) s.handleFailoverSideEffects(ctx, resp, account) @@ -3410,13 +4289,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }) if s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover: %s", account.ID, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), ) } else { - log.Printf("Account %d: 400 error, attempting failover", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID) } s.handleFailoverSideEffects(ctx, resp, account) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} @@ -3426,6 +4305,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 处理正常响应 + + // 触发上游接受回调(提前释放串行锁,不等流完成) + if parsed.OnUpstreamAccepted != nil { + parsed.OnUpstreamAccepted() + } + var usage *ClaudeUsage var firstTokenMs *int var clientDisconnect bool @@ -3460,6 +4345,602 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }, nil } +func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + reqModel string, + reqStream bool, + startTime time.Time, +) (*ForwardResult, error) { + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + if tokenType != "apikey" { + return nil, fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", + account.ID, account.Name, reqModel, reqStream) + + if c != nil { + c.Set("anthropic_passthrough", true) + } + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 + setOpsUpstreamRequestBody(c, body) + + var resp *http.Response + retryStart := time.Now() + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + if err != nil { + return nil, err + } + + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + + // 透传分支禁止 400 请求体降级重试(该重试会改写请求体) + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if attempt < maxRetryAttempts { + elapsed := time.Since(retryStart) + if elapsed >= maxRetryElapsed { + break + } + + delay := retryBackoffDelay(attempt) + remaining := maxRetryElapsed - elapsed + if delay > remaining { + delay = remaining + } + if delay <= 0 { + break + } + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "retry", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + logger.LegacyPrintf("service.gateway", "Anthropic passthrough account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", + account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + break + } + + break + } + if resp == nil || resp.Body == nil { + return nil, errors.New("upstream request failed: empty response") + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + + s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "retry_exhausted_failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + + if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + + if resp.StatusCode >= 400 { + return s.handleErrorResponse(ctx, resp, c, account) + } + + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + if reqStream { + streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel) + if err != nil { + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect + } else { + usage, err = s.handleNonStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account) + if err != nil { + return nil, err + } + } + if usage == nil { + usage = &ClaudeUsage{} + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: reqModel, + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + }, nil +} + +func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := claudeAPIURL + baseURL := account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if !allowedHeaders[lowerKey] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // 覆盖入站鉴权残留,并注入上游认证 + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Del("cookie") + req.Header.Set("x-api-key", token) + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + + return req, nil +} + +func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + startTime time.Time, + model string, +) (*streamingResult, error) { + if s.rateLimitService != nil { + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + } + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "text/event-stream" + } + c.Header("Content-Type", contentType) + if c.Writer.Header().Get("Cache-Control") == "" { + c.Header("Cache-Control", "no-cache") + } + if c.Writer.Header().Get("Connection") == "" { + c.Header("Connection", "keep-alive") + } + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + usage := &ClaudeUsage{} + var firstTokenMs *int + clientDisconnected := false + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + for { + select { + case ev, ok := <-events: + if !ok { + if !clientDisconnected { + // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 + flusher.Flush() + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if ev.err != nil { + if clientDisconnected { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v", + account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err()) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + if errors.Is(ev.err, bufio.ErrTooLong) { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) + } + + line := ev.line + if data, ok := extractAnthropicSSEDataLine(line); ok { + trimmed := strings.TrimSpace(data) + if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsagePassthrough(data, usage) + } + + if !clientDisconnected { + if _, err := io.WriteString(w, line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if _, err := io.WriteString(w, "\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if line == "" { + // 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。 + flusher.Flush() + } + } + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, model) + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + } + } +} + +func extractAnthropicSSEDataLine(line string) (string, bool) { + if !strings.HasPrefix(line, "data:") { + return "", false + } + start := len("data:") + for start < len(line) { + if line[start] != ' ' && line[start] != '\t' { + break + } + start++ + } + return line[start:], true +} + +func (s *GatewayService) parseSSEUsagePassthrough(data string, usage *ClaudeUsage) { + if usage == nil || data == "" || data == "[DONE]" { + return + } + + parsed := gjson.Parse(data) + switch parsed.Get("type").String() { + case "message_start": + msgUsage := parsed.Get("message.usage") + if msgUsage.Exists() { + usage.InputTokens = int(msgUsage.Get("input_tokens").Int()) + usage.CacheCreationInputTokens = int(msgUsage.Get("cache_creation_input_tokens").Int()) + usage.CacheReadInputTokens = int(msgUsage.Get("cache_read_input_tokens").Int()) + + // 保持与通用解析一致:message_start 允许覆盖 5m/1h 明细(包括 0)。 + cc5m := msgUsage.Get("cache_creation.ephemeral_5m_input_tokens") + cc1h := msgUsage.Get("cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + usage.CacheCreation5mTokens = int(cc5m.Int()) + usage.CacheCreation1hTokens = int(cc1h.Int()) + } + } + case "message_delta": + deltaUsage := parsed.Get("usage") + if deltaUsage.Exists() { + if v := deltaUsage.Get("input_tokens").Int(); v > 0 { + usage.InputTokens = int(v) + } + if v := deltaUsage.Get("output_tokens").Int(); v > 0 { + usage.OutputTokens = int(v) + } + if v := deltaUsage.Get("cache_creation_input_tokens").Int(); v > 0 { + usage.CacheCreationInputTokens = int(v) + } + if v := deltaUsage.Get("cache_read_input_tokens").Int(); v > 0 { + usage.CacheReadInputTokens = int(v) + } + + cc5m := deltaUsage.Get("cache_creation.ephemeral_5m_input_tokens") + cc1h := deltaUsage.Get("cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() && cc5m.Int() > 0 { + usage.CacheCreation5mTokens = int(cc5m.Int()) + } + if cc1h.Exists() && cc1h.Int() > 0 { + usage.CacheCreation1hTokens = int(cc1h.Int()) + } + } + } + + if usage.CacheReadInputTokens == 0 { + if cached := parsed.Get("message.usage.cached_tokens").Int(); cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + if cached := parsed.Get("usage.cached_tokens").Int(); usage.CacheReadInputTokens == 0 && cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + } + if usage.CacheCreationInputTokens == 0 { + cc5m := parsed.Get("message.usage.cache_creation.ephemeral_5m_input_tokens").Int() + cc1h := parsed.Get("message.usage.cache_creation.ephemeral_1h_input_tokens").Int() + if cc5m == 0 && cc1h == 0 { + cc5m = parsed.Get("usage.cache_creation.ephemeral_5m_input_tokens").Int() + cc1h = parsed.Get("usage.cache_creation.ephemeral_1h_input_tokens").Int() + } + total := cc5m + cc1h + if total > 0 { + usage.CacheCreationInputTokens = int(total) + } + } +} + +func parseClaudeUsageFromResponseBody(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} + if len(body) == 0 { + return usage + } + + parsed := gjson.ParseBytes(body) + usageNode := parsed.Get("usage") + if !usageNode.Exists() { + return usage + } + + usage.InputTokens = int(usageNode.Get("input_tokens").Int()) + usage.OutputTokens = int(usageNode.Get("output_tokens").Int()) + usage.CacheCreationInputTokens = int(usageNode.Get("cache_creation_input_tokens").Int()) + usage.CacheReadInputTokens = int(usageNode.Get("cache_read_input_tokens").Int()) + + cc5m := usageNode.Get("cache_creation.ephemeral_5m_input_tokens").Int() + cc1h := usageNode.Get("cache_creation.ephemeral_1h_input_tokens").Int() + if cc5m > 0 || cc1h > 0 { + usage.CacheCreation5mTokens = int(cc5m) + usage.CacheCreation1hTokens = int(cc1h) + } + if usage.CacheCreationInputTokens == 0 && (cc5m > 0 || cc1h > 0) { + usage.CacheCreationInputTokens = int(cc5m + cc1h) + } + if usage.CacheReadInputTokens == 0 { + if cached := usageNode.Get("cached_tokens").Int(); cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + } + return usage +} + +func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ClaudeUsage, error) { + if s.rateLimitService != nil { + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + } + + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + usage := parseClaudeUsageFromResponseBody(body) + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + return usage, nil +} + +func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { + if dst == nil || src == nil { + return + } + if filter != nil { + responseheaders.WriteFilteredHeaders(dst, src, filter) + return + } + if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { + dst.Set("Content-Type", v) + } + if v := strings.TrimSpace(src.Get("x-request-id")); v != "" { + dst.Set("x-request-id", v) + } +} + func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL @@ -3485,7 +4966,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 1. 获取或创建指纹(包含随机生成的ClientID) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if err != nil { - log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for account %d: %v", account.ID, err) // 失败时降级为透传原始headers } else { fingerprint = fp @@ -3552,12 +5033,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // messages requests typically use only oauth + interleaved-thinking. // Also drop claude-code beta if a downstream client added it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - drop := map[string]struct{}{claude.BetaClaudeCode: {}} - req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, droppedBetasWithClaudeCodeSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := req.Header.Get("anthropic-beta") - req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader)) + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), defaultDroppedBetasSet)) } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) @@ -3632,7 +5112,8 @@ func requestNeedsBetaFeatures(body []byte) bool { if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { return true } - if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") { + thinkingType := gjson.GetBytes(body, "thinking.type").String() + if strings.EqualFold(thinkingType, "enabled") || strings.EqualFold(thinkingType, "adaptive") { return true } return false @@ -3710,6 +5191,64 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str return strings.Join(out, ",") } +// stripBetaTokens removes the given beta tokens from a comma-separated header value. +func stripBetaTokens(header string, tokens []string) string { + if header == "" || len(tokens) == 0 { + return header + } + return stripBetaTokensWithSet(header, buildBetaTokenSet(tokens)) +} + +func stripBetaTokensWithSet(header string, drop map[string]struct{}) string { + if header == "" || len(drop) == 0 { + return header + } + parts := strings.Split(header, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if _, ok := drop[p]; ok { + continue + } + out = append(out, p) + } + if len(out) == len(parts) { + return header // no change, avoid allocation + } + return strings.Join(out, ",") +} + +// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. +func droppedBetaSet(extra ...string) map[string]struct{} { + m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) + for t := range defaultDroppedBetasSet { + m[t] = struct{}{} + } + for _, t := range extra { + m[t] = struct{}{} + } + return m +} + +func buildBetaTokenSet(tokens []string) map[string]struct{} { + m := make(map[string]struct{}, len(tokens)) + for _, t := range tokens { + if t == "" { + continue + } + m[t] = struct{}{} + } + return m +} + +var ( + defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas) + droppedBetasWithClaudeCodeSet = droppedBetaSet(claude.BetaClaudeCode) +) + // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. // This mirrors opencode-anthropic-auth behavior: do not trust downstream // headers when using Claude Code-scoped OAuth credentials. @@ -3756,33 +5295,33 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { } // Log for debugging - log.Printf("[SignatureCheck] Checking error message: %s", msg) + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Checking error message: %s", msg) // 检测signature相关的错误(更宽松的匹配) // 例如: "Invalid `signature` in `thinking` block", "***.signature" 等 if strings.Contains(msg, "signature") { - log.Printf("[SignatureCheck] Detected signature error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected signature error") return true } // 检测 thinking block 顺序/类型错误 // 例如: "Expected `thinking` or `redacted_thinking`, but found `text`" if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { - log.Printf("[SignatureCheck] Detected thinking block type error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block type error") return true } // 检测 thinking block 被修改的错误 // 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified" if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { - log.Printf("[SignatureCheck] Detected thinking block modification error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block modification error") return true } // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的) // 例如: "all messages must have non-empty content" if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") { - log.Printf("[SignatureCheck] Detected empty content error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error") return true } @@ -3790,7 +5329,7 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { } func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { - // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 + // 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。 // 默认保守:无法识别则不切换。 msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) if msg == "" { @@ -3839,11 +5378,25 @@ func extractUpstreamErrorMessage(body []byte) string { return gjson.GetBytes(body, "message").String() } +func isCountTokensUnsupported404(statusCode int, body []byte) bool { + if statusCode != http.StatusNotFound { + return false + } + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body))) + if msg == "" { + return false + } + if strings.Contains(msg, "/v1/messages/count_tokens") { + return true + } + return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found") +} + func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) // 调试日志:打印上游错误响应 - log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) @@ -3854,7 +5407,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { if v, ok := c.Get(claudeMimicDebugInfoKey); ok { if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { - log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", resp.StatusCode, resp.Header.Get("x-request-id"), line, @@ -3894,7 +5447,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res // 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "Upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -3995,10 +5548,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re // OAuth/Setup Token 账号的 403:标记账号异常 if account.IsOAuth() && statusCode == 403 { s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) - log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode) + logger.LegacyPrintf("service.gateway", "Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode) } else { // API Key 未配置错误码:不标记账号状态 - log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) + logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) } } @@ -4024,7 +5577,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { if v, ok := c.Get(claudeMimicDebugInfoKey); ok { if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { - log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", resp.StatusCode, resp.Header.Get("x-request-id"), line, @@ -4053,7 +5606,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht }) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "Upstream error %d retries_exhausted (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -4116,8 +5669,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } // 设置SSE响应头 @@ -4145,7 +5698,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) type scanEvent struct { line string @@ -4164,7 +5718,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -4175,7 +5730,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) @@ -4209,9 +5764,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http pendingEventLines := make([]string, 0, 4) - processSSEEvent := func(lines []string) ([]string, string, error) { + processSSEEvent := func(lines []string) ([]string, string, *sseUsagePatch, error) { if len(lines) == 0 { - return nil, "", nil + return nil, "", nil, nil } eventName := "" @@ -4228,11 +5783,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if eventName == "error" { - return nil, dataLine, errors.New("have error in stream") + return nil, dataLine, nil, errors.New("have error in stream") } if dataLine == "" { - return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil + return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil, nil } if dataLine == "[DONE]" { @@ -4241,7 +5796,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, nil, nil } var event map[string]any @@ -4252,25 +5807,43 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, nil, nil } eventType, _ := event["type"].(string) if eventName == "" { eventName = eventType } + eventChanged := false // 兼容 Kimi cached_tokens → cache_read_input_tokens if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { - reconcileCachedTokens(u) + eventChanged = reconcileCachedTokens(u) || eventChanged } } } if eventType == "message_delta" { if u, ok := event["usage"].(map[string]any); ok { - reconcileCachedTokens(u) + eventChanged = reconcileCachedTokens(u) || eventChanged + } + } + + // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 + if account.IsCacheTTLOverrideEnabled() { + overrideTarget := account.GetCacheTTLOverrideTarget() + if eventType == "message_start" { + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged + } + } + } + if eventType == "message_delta" { + if u, ok := event["usage"].(map[string]any); ok { + eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged + } } } @@ -4278,10 +5851,21 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if msg, ok := event["message"].(map[string]any); ok { if model, ok := msg["model"].(string); ok && model == mappedModel { msg["model"] = originalModel + eventChanged = true } } } + usagePatch := s.extractSSEUsagePatch(event) + if !eventChanged { + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, usagePatch, nil + } + newData, err := json.Marshal(event) if err != nil { // 序列化失败,直接透传原始数据 @@ -4290,7 +5874,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, usagePatch, nil } block := "" @@ -4298,7 +5882,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + string(newData) + "\n\n" - return []string{block}, string(newData), nil + return []string{block}, string(newData), usagePatch, nil } for { @@ -4311,17 +5895,17 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if ev.err != nil { // 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取) if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - log.Printf("Context canceled during streaming, returning collected usage") + logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage if clientDisconnected { - log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err) + logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } // 客户端未断开,正常的错误处理 if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) sendErrorEvent("response_too_large") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } @@ -4336,7 +5920,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http continue } - outputBlocks, data, err := processSSEEvent(pendingEventLines) + outputBlocks, data, usagePatch, err := processSSEEvent(pendingEventLines) pendingEventLines = pendingEventLines[:0] if err != nil { if clientDisconnected { @@ -4349,7 +5933,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if !clientDisconnected { if _, werr := fmt.Fprint(w, block); werr != nil { clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing") break } flusher.Flush() @@ -4359,7 +5943,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - s.parseSSEUsage(data, usage) + if usagePatch != nil { + mergeSSEUsagePatch(usage, usagePatch) + } } } continue @@ -4374,10 +5960,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if clientDisconnected { // 客户端已断开,上游也超时了,返回已收集的 usage - log.Printf("Upstream timeout after client disconnect, returning collected usage") + logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } - log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 if s.rateLimitService != nil { s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) @@ -4390,54 +5976,241 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { - // 解析message_start获取input tokens(标准Claude API格式) - var msgStart struct { - Type string `json:"type"` - Message struct { - Usage ClaudeUsage `json:"usage"` - } `json:"message"` - } - if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" { - usage.InputTokens = msgStart.Message.Usage.InputTokens - usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens - usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens + if usage == nil { + return } - // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API) - var msgDelta struct { - Type string `json:"type"` - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CacheCreationInputTokens int `json:"cache_creation_input_tokens"` - CacheReadInputTokens int `json:"cache_read_input_tokens"` - } `json:"usage"` + var event map[string]any + if err := json.Unmarshal([]byte(data), &event); err != nil { + return } - if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" { - // message_delta 仅覆盖存在且非0的字段 - // 避免覆盖 message_start 中已有的值(如 input_tokens) - // Claude API 的 message_delta 通常只包含 output_tokens - if msgDelta.Usage.InputTokens > 0 { - usage.InputTokens = msgDelta.Usage.InputTokens + + if patch := s.extractSSEUsagePatch(event); patch != nil { + mergeSSEUsagePatch(usage, patch) + } +} + +type sseUsagePatch struct { + inputTokens int + hasInputTokens bool + outputTokens int + hasOutputTokens bool + cacheCreationInputTokens int + hasCacheCreationInput bool + cacheReadInputTokens int + hasCacheReadInput bool + cacheCreation5mTokens int + hasCacheCreation5m bool + cacheCreation1hTokens int + hasCacheCreation1h bool +} + +func (s *GatewayService) extractSSEUsagePatch(event map[string]any) *sseUsagePatch { + if len(event) == 0 { + return nil + } + + eventType, _ := event["type"].(string) + switch eventType { + case "message_start": + msg, _ := event["message"].(map[string]any) + usageObj, _ := msg["usage"].(map[string]any) + if len(usageObj) == 0 { + return nil } - if msgDelta.Usage.OutputTokens > 0 { - usage.OutputTokens = msgDelta.Usage.OutputTokens + + patch := &sseUsagePatch{} + patch.hasInputTokens = true + if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok { + patch.inputTokens = v } - if msgDelta.Usage.CacheCreationInputTokens > 0 { - usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens + patch.hasCacheCreationInput = true + if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok { + patch.cacheCreationInputTokens = v } - if msgDelta.Usage.CacheReadInputTokens > 0 { - usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens + patch.hasCacheReadInput = true + if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok { + patch.cacheReadInputTokens = v + } + if cc, ok := usageObj["cache_creation"].(map[string]any); ok { + if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists { + patch.cacheCreation5mTokens = v + patch.hasCacheCreation5m = true + } + if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists { + patch.cacheCreation1hTokens = v + patch.hasCacheCreation1h = true + } + } + return patch + + case "message_delta": + usageObj, _ := event["usage"].(map[string]any) + if len(usageObj) == 0 { + return nil + } + + patch := &sseUsagePatch{} + if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok && v > 0 { + patch.inputTokens = v + patch.hasInputTokens = true + } + if v, ok := parseSSEUsageInt(usageObj["output_tokens"]); ok && v > 0 { + patch.outputTokens = v + patch.hasOutputTokens = true + } + if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok && v > 0 { + patch.cacheCreationInputTokens = v + patch.hasCacheCreationInput = true + } + if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok && v > 0 { + patch.cacheReadInputTokens = v + patch.hasCacheReadInput = true + } + if cc, ok := usageObj["cache_creation"].(map[string]any); ok { + if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists && v > 0 { + patch.cacheCreation5mTokens = v + patch.hasCacheCreation5m = true + } + if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists && v > 0 { + patch.cacheCreation1hTokens = v + patch.hasCacheCreation1h = true + } + } + return patch + } + + return nil +} + +func mergeSSEUsagePatch(usage *ClaudeUsage, patch *sseUsagePatch) { + if usage == nil || patch == nil { + return + } + + if patch.hasInputTokens { + usage.InputTokens = patch.inputTokens + } + if patch.hasCacheCreationInput { + usage.CacheCreationInputTokens = patch.cacheCreationInputTokens + } + if patch.hasCacheReadInput { + usage.CacheReadInputTokens = patch.cacheReadInputTokens + } + if patch.hasOutputTokens { + usage.OutputTokens = patch.outputTokens + } + if patch.hasCacheCreation5m { + usage.CacheCreation5mTokens = patch.cacheCreation5mTokens + } + if patch.hasCacheCreation1h { + usage.CacheCreation1hTokens = patch.cacheCreation1hTokens + } +} + +func parseSSEUsageInt(value any) (int, bool) { + switch v := value.(type) { + case float64: + return int(v), true + case float32: + return int(v), true + case int: + return v, true + case int64: + return int(v), true + case int32: + return int(v), true + case json.Number: + if i, err := v.Int64(); err == nil { + return int(i), true + } + if f, err := v.Float64(); err == nil { + return int(f), true + } + case string: + if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil { + return parsed, true } } + return 0, false +} + +// applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。 +// target 为 "5m" 或 "1h"。返回 true 表示发生了变更。 +func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool { + // Fallback: 如果只有聚合字段但无 5m/1h 明细,将聚合字段归入 5m 默认类别 + if usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0 && usage.CacheCreationInputTokens > 0 { + usage.CacheCreation5mTokens = usage.CacheCreationInputTokens + } + + total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens + if total == 0 { + return false + } + switch target { + case "1h": + if usage.CacheCreation1hTokens == total { + return false // 已经全是 1h + } + usage.CacheCreation1hTokens = total + usage.CacheCreation5mTokens = 0 + default: // "5m" + if usage.CacheCreation5mTokens == total { + return false // 已经全是 5m + } + usage.CacheCreation5mTokens = total + usage.CacheCreation1hTokens = 0 + } + return true +} + +// rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。 +// usageObj 是 usage JSON 对象(map[string]any)。 +func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool { + ccObj, ok := usageObj["cache_creation"].(map[string]any) + if !ok { + return false + } + v5m, _ := parseSSEUsageInt(ccObj["ephemeral_5m_input_tokens"]) + v1h, _ := parseSSEUsageInt(ccObj["ephemeral_1h_input_tokens"]) + total := v5m + v1h + if total == 0 { + return false + } + switch target { + case "1h": + if v1h == total { + return false + } + ccObj["ephemeral_1h_input_tokens"] = float64(total) + ccObj["ephemeral_5m_input_tokens"] = float64(0) + default: // "5m" + if v5m == total { + return false + } + ccObj["ephemeral_5m_input_tokens"] = float64(total) + ccObj["ephemeral_1h_input_tokens"] = float64(0) + } + return true } func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -4449,6 +6222,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h return nil, fmt.Errorf("parse response: %w", err) } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens") + cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + response.Usage.CacheCreation5mTokens = int(cc5m.Int()) + response.Usage.CacheCreation1hTokens = int(cc1h.Int()) + } + // 兼容 Kimi cached_tokens → cache_read_input_tokens if response.Usage.CacheReadInputTokens == 0 { cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() @@ -4460,12 +6241,26 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } + // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 + if account.IsCacheTTLOverrideEnabled() { + overrideTarget := account.GetCacheTTLOverrideTarget() + if applyCacheTTLOverride(&response.Usage, overrideTarget) { + // 同步更新 body JSON 中的嵌套 cache_creation 对象 + if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil { + body = newBody + } + if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil { + body = newBody + } + } + } + // 如果有模型映射,替换响应中的model字段 if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json" if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { @@ -4481,24 +6276,76 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } // replaceModelInResponseBody 替换响应体中的model字段 +// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化 func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return body + if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { + newBody, err := sjson.SetBytes(body, "model", toModel) + if err != nil { + return body + } + return newBody + } + return body +} + +func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { + if s == nil || userID <= 0 || groupID <= 0 { + return groupDefaultMultiplier } - model, ok := resp["model"].(string) - if !ok || model != fromModel { - return body + key := fmt.Sprintf("%d:%d", userID, groupID) + if s.userGroupRateCache != nil { + if cached, ok := s.userGroupRateCache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier + } + } } + if s.userGroupRateRepo == nil { + return groupDefaultMultiplier + } + userGroupRateCacheMissTotal.Add(1) - resp["model"] = toModel - newBody, err := json.Marshal(resp) + value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) { + if s.userGroupRateCache != nil { + if cached, ok := s.userGroupRateCache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier, nil + } + } + } + + userGroupRateCacheLoadTotal.Add(1) + userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID) + if repoErr != nil { + return nil, repoErr + } + multiplier := groupDefaultMultiplier + if userRate != nil { + multiplier = *userRate + } + if s.userGroupRateCache != nil { + s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg)) + } + return multiplier, nil + }) + if shared { + userGroupRateCacheSFSharedTotal.Add(1) + } if err != nil { - return body + userGroupRateCacheFallbackTotal.Add(1) + logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err) + return groupDefaultMultiplier } - return newBody + multiplier, ok := value.(float64) + if !ok { + userGroupRateCacheFallbackTotal.Add(1) + return groupDefaultMultiplier + } + return multiplier } // RecordUsageInput 记录使用量的输入参数 @@ -4530,29 +6377,50 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens // 用于粘性会话切换时的特殊计费处理 if input.ForceCacheBilling && result.Usage.InputTokens > 0 { - log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", + logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", result.Usage.InputTokens, account.ID) result.Usage.CacheReadInputTokens += result.Usage.InputTokens result.Usage.InputTokens = 0 } - // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := s.cfg.Default.RateMultiplier - if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { + applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 + } - // 检查用户专属倍率 - if s.userGroupRateRepo != nil { - if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { - multiplier = *userRate - } - } + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } + if apiKey.GroupID != nil && apiKey.Group != nil { + groupDefault := apiKey.Group.RateMultiplier + multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } var cost *CostBreakdown // 根据请求类型选择计费方式 - if result.ImageCount > 0 { + if result.MediaType == "image" || result.MediaType == "video" { + var soraConfig *SoraPriceConfig + if apiKey.Group != nil { + soraConfig = &SoraPriceConfig{ + ImagePrice360: apiKey.Group.SoraImagePrice360, + ImagePrice540: apiKey.Group.SoraImagePrice540, + VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, + } + } + if result.MediaType == "image" { + cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) + } else { + cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) + } + } else if result.MediaType == "prompt" { + cost = &CostBreakdown{} + } else if result.ImageCount > 0 { // 图片生成计费 var groupConfig *ImagePriceConfig if apiKey.Group != nil { @@ -4566,15 +6434,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } else { // Token 计费 tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) if err != nil { - log.Printf("Calculate cost failed: %v", err) + logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} } } @@ -4592,6 +6462,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu if result.ImageSize != "" { imageSize = &result.ImageSize } + var mediaType *string + if strings.TrimSpace(result.MediaType) != "" { + mediaType = &result.MediaType + } accountRateMultiplier := account.BillingRateMultiplier() usageLog := &UsageLog{ UserID: user.ID, @@ -4603,6 +6477,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, InputCost: cost.InputCost, OutputCost: cost.OutputCost, CacheCreationCost: cost.CacheCreationCost, @@ -4617,6 +6493,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, + MediaType: mediaType, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: time.Now(), } @@ -4640,11 +6518,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { - log.Printf("Create usage log failed: %v", err) + logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -4656,7 +6534,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) if shouldBill && cost.TotalCost > 0 { if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - log.Printf("Increment subscription usage failed: %v", err) + logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) } // 异步更新订阅缓存 s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) @@ -4665,7 +6543,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) if shouldBill && cost.ActualCost > 0 { if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - log.Printf("Deduct balance failed: %v", err) + logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) } // 异步更新余额缓存 s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) @@ -4675,7 +6553,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 更新 API Key 配额(如果设置了配额限制) if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - log.Printf("Update API key quota failed: %v", err) + logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err) } } @@ -4711,23 +6589,27 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens // 用于粘性会话切换时的特殊计费处理 if input.ForceCacheBilling && result.Usage.InputTokens > 0 { - log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", + logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", result.Usage.InputTokens, account.ID) result.Usage.CacheReadInputTokens += result.Usage.InputTokens result.Usage.InputTokens = 0 } - // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := s.cfg.Default.RateMultiplier - if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { + applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 + } - // 检查用户专属倍率 - if s.userGroupRateRepo != nil { - if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { - multiplier = *userRate - } - } + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } + if apiKey.GroupID != nil && apiKey.Group != nil { + groupDefault := apiKey.Group.RateMultiplier + multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } var cost *CostBreakdown @@ -4747,15 +6629,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } else { // Token 计费(使用长上下文计费方法) tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) if err != nil { - log.Printf("Calculate cost failed: %v", err) + logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} } } @@ -4784,6 +6668,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, InputCost: cost.InputCost, OutputCost: cost.OutputCost, CacheCreationCost: cost.CacheCreationCost, @@ -4798,6 +6684,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: time.Now(), } @@ -4821,11 +6708,11 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { - log.Printf("Create usage log failed: %v", err) + logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -4837,7 +6724,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) if shouldBill && cost.TotalCost > 0 { if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - log.Printf("Increment subscription usage failed: %v", err) + logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) } // 异步更新订阅缓存 s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) @@ -4846,14 +6733,14 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) if shouldBill && cost.ActualCost > 0 { if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - log.Printf("Deduct balance failed: %v", err) + logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) } // 异步更新余额缓存 s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) // API Key 独立配额扣费 if input.APIKeyService != nil && apiKey.Quota > 0 { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - log.Printf("Add API key quota used failed: %v", err) + logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err) } } } @@ -4873,6 +6760,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return fmt.Errorf("parse request: empty request") } + if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { + return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body) + } + body := parsed.Body reqModel := parsed.Model @@ -4884,9 +6775,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } - // Antigravity 账户不支持 count_tokens 转发,直接返回空值 + // Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。 + // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 if account.Platform == PlatformAntigravity { - c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform") return nil } @@ -4912,7 +6804,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, if mappedModel != reqModel { body = s.replaceModelInBody(body, mappedModel) reqModel = mappedModel - log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) + logger.LegacyPrintf("service.gateway", "CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) } } @@ -4945,16 +6837,22 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 读取响应体 - respBody, err := io.ReadAll(resp.Body) + maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) _ = resp.Body.Close() if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } // 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks) if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) { - log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) filteredBody := FilterThinkingBlocksForRetry(body) retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) @@ -4962,9 +6860,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { resp = retryResp - respBody, err = io.ReadAll(resp.Body) + respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) _ = resp.Body.Close() if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } @@ -4991,7 +6894,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // 记录上游错误摘要便于排障(不回显请求内容) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -5021,6 +6924,170 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return nil } +func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx context.Context, c *gin.Context, account *Account, body []byte) error { + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") + return err + } + if tokenType != "apikey" { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Invalid account token type") + return fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) + } + + upstreamReq, err := s.buildCountTokensRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + if err != nil { + s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + return err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: sanitizeUpstreamErrorMessage(err.Error()), + }) + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") + return fmt.Errorf("upstream request failed: %w", err) + } + + maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + _ = resp.Body.Close() + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + return err + } + + if resp.StatusCode >= 400 { + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + // 中转站不支持 count_tokens 端点时(404),返回 404 让客户端 fallback 到本地估算。 + // 仅在错误消息明确指向 count_tokens endpoint 不存在时生效,避免误吞其他 404(如错误 base_url)。 + // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 + if isCountTokensUnsupported404(resp.StatusCode, respBody) { + logger.LegacyPrintf("service.gateway", + "[count_tokens] Upstream does not support count_tokens (404), returning 404: account=%d name=%s msg=%s", + account.ID, account.Name, truncateString(upstreamMsg, 512)) + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported by upstream") + return nil + } + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + errMsg := "Upstream request failed" + switch resp.StatusCode { + case 429: + errMsg = "Rate limit exceeded" + case 529: + errMsg = "Service overloaded" + } + s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg) + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, respBody) + return nil +} + +func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := claudeAPICountTokensURL + baseURL := account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages/count_tokens" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if !allowedHeaders[lowerKey] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Del("cookie") + req.Header.Set("x-api-key", token) + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + + return req, nil +} + // buildCountTokensRequest 构建 count_tokens 上游请求 func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) { // 确定目标 URL @@ -5103,7 +7170,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con incomingBeta := req.Header.Get("anthropic-beta") requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} - req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta)) + drop := droppedBetaSet() + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) } else { clientBetaHeader := req.Header.Get("anthropic-beta") if clientBetaHeader == "" { @@ -5113,7 +7181,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if !strings.Contains(beta, claude.BetaTokenCounting) { beta = beta + "," + claude.BetaTokenCounting } - req.Header.Set("anthropic-beta", beta) + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, defaultDroppedBetasSet)) } } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { @@ -5165,30 +7233,20 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { return normalized, nil } -// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内 -func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error { - scope, ok := ResolveAntigravityQuotaScope(requestedModel) - if !ok { - return nil // 无法解析 scope,跳过检查 - } - - group, err := s.resolveGroupByID(ctx, groupID) - if err != nil { - return nil // 查询失败时放行 - } - if group == nil { - return nil // 分组不存在时放行 - } - - if !IsScopeSupported(group.SupportedModelScopes, scope) { - return ErrModelScopeNotSupported - } - return nil -} - // GetAvailableModels returns the list of models available for a group // It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { + cacheKey := modelsListCacheKey(groupID, platform) + if s.modelsListCache != nil { + if cached, found := s.modelsListCache.Get(cacheKey); found { + if models, ok := cached.([]string); ok { + modelsListCacheHitTotal.Add(1) + return cloneStringSlice(models) + } + } + } + modelsListCacheMissTotal.Add(1) + var accounts []Account var err error @@ -5229,6 +7287,10 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, // If no account has model_mapping, return nil (use default) if !hasAnyMapping { + if s.modelsListCache != nil { + s.modelsListCache.Set(cacheKey, []string(nil), s.modelsListCacheTTL) + modelsListCacheStoreTotal.Add(1) + } return nil } @@ -5237,8 +7299,45 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, for model := range modelSet { models = append(models, model) } + sort.Strings(models) - return models + if s.modelsListCache != nil { + s.modelsListCache.Set(cacheKey, cloneStringSlice(models), s.modelsListCacheTTL) + modelsListCacheStoreTotal.Add(1) + } + return cloneStringSlice(models) +} + +func (s *GatewayService) InvalidateAvailableModelsCache(groupID *int64, platform string) { + if s == nil || s.modelsListCache == nil { + return + } + + normalizedPlatform := strings.TrimSpace(platform) + // 完整匹配时精准失效;否则按维度批量失效。 + if groupID != nil && normalizedPlatform != "" { + s.modelsListCache.Delete(modelsListCacheKey(groupID, normalizedPlatform)) + return + } + + targetGroup := derefGroupID(groupID) + for key := range s.modelsListCache.Items() { + parts := strings.SplitN(key, "|", 2) + if len(parts) != 2 { + continue + } + groupPart, parseErr := strconv.ParseInt(parts[0], 10, 64) + if parseErr != nil { + continue + } + if groupID != nil && groupPart != targetGroup { + continue + } + if normalizedPlatform != "" && parts[1] != normalizedPlatform { + continue + } + s.modelsListCache.Delete(key) + } } // reconcileCachedTokens 兼容 Kimi 等上游: diff --git a/backend/internal/service/gateway_service_benchmark_test.go b/backend/internal/service/gateway_service_benchmark_test.go index f15a85d6..c9c4d3dd 100644 --- a/backend/internal/service/gateway_service_benchmark_test.go +++ b/backend/internal/service/gateway_service_benchmark_test.go @@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") if err != nil { b.Fatalf("解析请求失败: %v", err) } diff --git a/backend/internal/service/gateway_service_selection_failure_stats_test.go b/backend/internal/service/gateway_service_selection_failure_stats_test.go new file mode 100644 index 00000000..743d70bb --- /dev/null +++ b/backend/internal/service/gateway_service_selection_failure_stats_test.go @@ -0,0 +1,141 @@ +package service + +import ( + "context" + "strings" + "testing" + "time" +) + +func TestCollectSelectionFailureStats(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339) + + accounts := []Account{ + // excluded + { + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + }, + // unschedulable + { + ID: 2, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: false, + }, + // platform filtered + { + ID: 3, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + }, + // model unsupported + { + ID: 4, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-image": "gpt-image", + }, + }, + }, + // model rate limited + { + ID: 5, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + }, + // eligible + { + ID: 6, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + }, + } + + excluded := map[int64]struct{}{1: {}} + stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false) + + if stats.Total != 6 { + t.Fatalf("total=%d want=6", stats.Total) + } + if stats.Excluded != 1 { + t.Fatalf("excluded=%d want=1", stats.Excluded) + } + if stats.Unschedulable != 1 { + t.Fatalf("unschedulable=%d want=1", stats.Unschedulable) + } + if stats.PlatformFiltered != 1 { + t.Fatalf("platform_filtered=%d want=1", stats.PlatformFiltered) + } + if stats.ModelUnsupported != 1 { + t.Fatalf("model_unsupported=%d want=1", stats.ModelUnsupported) + } + if stats.ModelRateLimited != 1 { + t.Fatalf("model_rate_limited=%d want=1", stats.ModelRateLimited) + } + if stats.Eligible != 1 { + t.Fatalf("eligible=%d want=1", stats.Eligible) + } +} + +func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) { + svc := &GatewayService{} + acc := &Account{ + ID: 7, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: false, + } + + diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) + if diagnosis.Category != "unschedulable" { + t.Fatalf("category=%s want=unschedulable", diagnosis.Category) + } + if diagnosis.Detail != "schedulable=false" { + t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail) + } +} + +func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) + acc := &Account{ + ID: 8, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + } + + diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false) + if diagnosis.Category != "model_rate_limited" { + t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category) + } + if !strings.Contains(diagnosis.Detail, "remaining=") { + t.Fatalf("detail=%s want contains remaining=", diagnosis.Detail) + } +} diff --git a/backend/internal/service/gateway_service_sora_model_support_test.go b/backend/internal/service/gateway_service_sora_model_support_test.go new file mode 100644 index 00000000..8ee2a960 --- /dev/null +++ b/backend/internal/service/gateway_service_sora_model_support_test.go @@ -0,0 +1,79 @@ +package service + +import "testing" + +func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{}, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected sora model to be supported when model_mapping is empty") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-4o": "gpt-4o", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected sora model to be supported when mapping has no sora selectors") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "sora2": "sora2", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") { + t.Fatalf("expected family selector sora2 to support sora2-landscape-15s") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "sy_8": "sy_8", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-image": "gpt-image", + }, + }, + } + + if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image") + } +} diff --git a/backend/internal/service/gateway_service_sora_scheduling_test.go b/backend/internal/service/gateway_service_sora_scheduling_test.go new file mode 100644 index 00000000..5178e68e --- /dev/null +++ b/backend/internal/service/gateway_service_sora_scheduling_test.go @@ -0,0 +1,89 @@ +package service + +import ( + "context" + "testing" + "time" +) + +func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) { + svc := &GatewayService{} + now := time.Now() + past := now.Add(-1 * time.Minute) + future := now.Add(5 * time.Minute) + + acc := &Account{ + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + AutoPauseOnExpired: true, + ExpiresAt: &past, + OverloadUntil: &future, + RateLimitResetAt: &future, + } + + if !svc.isAccountSchedulableForSelection(acc) { + t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows") + } +} + +func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) { + svc := &GatewayService{} + future := time.Now().Add(5 * time.Minute) + + acc := &Account{ + Platform: PlatformAnthropic, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &future, + } + + if svc.isAccountSchedulableForSelection(acc) { + t.Fatalf("expected non-sora account to keep generic schedulable checks") + } +} + +func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) + globalResetAt := time.Now().Add(2 * time.Minute) + + acc := &Account{ + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &globalResetAt, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + } + + if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) { + t.Fatalf("expected sora account to be blocked by model scope rate limit") + } +} + +func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) { + svc := &GatewayService{} + future := time.Now().Add(3 * time.Minute) + + accounts := []Account{ + { + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &future, + }, + } + + stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) + if stats.Unschedulable != 0 || stats.Eligible != 1 { + t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible) + } +} diff --git a/backend/internal/service/gateway_service_streaming_test.go b/backend/internal/service/gateway_service_streaming_test.go new file mode 100644 index 00000000..c8803d39 --- /dev/null +++ b/backend/internal/service/gateway_service_streaming_test.go @@ -0,0 +1,52 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + + svc := &GatewayService{ + cfg: cfg, + rateLimitService: &RateLimitService{}, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + // Minimal SSE event to trigger parseSSEUsage + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":3}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 3, result.usage.InputTokens) + require.Equal(t, 7, result.usage.OutputTokens) +} diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go new file mode 100644 index 00000000..cd690cbd --- /dev/null +++ b/backend/internal/service/gateway_streaming_test.go @@ -0,0 +1,219 @@ +//go:build unit + +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// --- parseSSEUsage 测试 --- + +func newMinimalGatewayService() *GatewayService { + return &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } +} + +func TestParseSSEUsage_MessageStart(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + data := `{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation_input_tokens":50,"cache_read_input_tokens":200}}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 100, usage.InputTokens) + require.Equal(t, 50, usage.CacheCreationInputTokens) + require.Equal(t, 200, usage.CacheReadInputTokens) + require.Equal(t, 0, usage.OutputTokens, "message_start 不应设置 output_tokens") +} + +func TestParseSSEUsage_MessageDelta(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + data := `{"type":"message_delta","usage":{"output_tokens":42}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 42, usage.OutputTokens) + require.Equal(t, 0, usage.InputTokens, "message_delta 的 output_tokens 不应影响已有的 input_tokens") +} + +func TestParseSSEUsage_DeltaDoesNotOverwriteStartValues(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 先处理 message_start + svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100}}}`, usage) + require.Equal(t, 100, usage.InputTokens) + + // 再处理 message_delta(output_tokens > 0, input_tokens = 0) + svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":50}}`, usage) + require.Equal(t, 100, usage.InputTokens, "delta 中 input_tokens=0 不应覆盖 start 中的值") + require.Equal(t, 50, usage.OutputTokens) +} + +func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // GLM 等 API 会在 delta 中包含所有 usage 信息 + svc.parseSSEUsage(`{"type":"message_delta","usage":{"input_tokens":200,"output_tokens":100,"cache_creation_input_tokens":30,"cache_read_input_tokens":60}}`, usage) + require.Equal(t, 200, usage.InputTokens) + require.Equal(t, 100, usage.OutputTokens) + require.Equal(t, 30, usage.CacheCreationInputTokens) + require.Equal(t, 60, usage.CacheReadInputTokens) +} + +func TestParseSSEUsage_DeltaDoesNotResetCacheCreationBreakdown(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 先在 message_start 中写入非零 5m/1h 明细 + svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation":{"ephemeral_5m_input_tokens":30,"ephemeral_1h_input_tokens":70}}}}`, usage) + require.Equal(t, 30, usage.CacheCreation5mTokens) + require.Equal(t, 70, usage.CacheCreation1hTokens) + + // 后续 delta 带默认 0,不应覆盖已有非零值 + svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":12,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0}}}`, usage) + require.Equal(t, 30, usage.CacheCreation5mTokens, "delta 的 0 值不应重置 5m 明细") + require.Equal(t, 70, usage.CacheCreation1hTokens, "delta 的 0 值不应重置 1h 明细") + require.Equal(t, 12, usage.OutputTokens) +} + +func TestParseSSEUsage_InvalidJSON(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 无效 JSON 不应 panic + svc.parseSSEUsage("not json", usage) + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) +} + +func TestParseSSEUsage_UnknownType(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 不是 message_start 或 message_delta 的类型 + svc.parseSSEUsage(`{"type":"content_block_delta","delta":{"text":"hello"}}`, usage) + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) +} + +func TestParseSSEUsage_EmptyString(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + svc.parseSSEUsage("", usage) + require.Equal(t, 0, usage.InputTokens) +} + +func TestParseSSEUsage_DoneEvent(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // [DONE] 事件不应影响 usage + svc.parseSSEUsage("[DONE]", usage) + require.Equal(t, 0, usage.InputTokens) +} + +// --- 流式响应端到端测试 --- + +func TestHandleStreamingResponse_CacheTokens(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":10,\"cache_creation_input_tokens\":20,\"cache_read_input_tokens\":30}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":15}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 10, result.usage.InputTokens) + require.Equal(t, 15, result.usage.OutputTokens) + require.Equal(t, 20, result.usage.CacheCreationInputTokens) + require.Equal(t, 30, result.usage.CacheReadInputTokens) +} + +func TestHandleStreamingResponse_EmptyStream(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + // 直接关闭,不发送任何事件 + _ = pw.Close() + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) +} + +func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + // 包含特殊字符的 content_block_delta(引号、换行、Unicode) + _, _ = pw.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello \\\"world\\\"\\n你好\"}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":3}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 5, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + + // 验证响应中包含转发的数据 + body := rec.Body.String() + require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件") +} diff --git a/backend/internal/service/gateway_waiting_queue_test.go b/backend/internal/service/gateway_waiting_queue_test.go new file mode 100644 index 00000000..0c53323e --- /dev/null +++ b/backend/internal/service/gateway_waiting_queue_test.go @@ -0,0 +1,120 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestDecrementWaitCount_NilCache 确保 nil cache 不会 panic +func TestDecrementWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + // 不应 panic + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestDecrementWaitCount_CacheError 确保 cache 错误不会传播 +func TestDecrementWaitCount_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{} + svc := NewConcurrencyService(cache) + // DecrementWaitCount 使用 background context,错误只记录日志不传播 + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestDecrementAccountWaitCount_NilCache 确保 nil cache 不会 panic +func TestDecrementAccountWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + svc.DecrementAccountWaitCount(context.Background(), 1) +} + +// TestDecrementAccountWaitCount_CacheError 确保 cache 错误不会传播 +func TestDecrementAccountWaitCount_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{} + svc := NewConcurrencyService(cache) + svc.DecrementAccountWaitCount(context.Background(), 1) +} + +// TestWaitingQueueFlow_IncrementThenDecrement 测试完整的等待队列增减流程 +func TestWaitingQueueFlow_IncrementThenDecrement(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + // 进入等待队列 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed) + + // 离开等待队列(不应 panic) + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestWaitingQueueFlow_AccountLevel 测试账号级等待队列流程 +func TestWaitingQueueFlow_AccountLevel(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + // 进入账号等待队列 + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 42, 10) + require.NoError(t, err) + require.True(t, allowed) + + // 离开账号等待队列 + svc.DecrementAccountWaitCount(context.Background(), 42) +} + +// TestWaitingQueueFull_Returns429Signal 测试等待队列满时返回 false +func TestWaitingQueueFull_Returns429Signal(t *testing.T) { + // waitAllowed=false 模拟队列已满 + cache := &stubConcurrencyCacheForTest{waitAllowed: false} + svc := NewConcurrencyService(cache) + + // 用户级等待队列满 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.False(t, allowed, "等待队列满时应返回 false(调用方根据此返回 429)") + + // 账号级等待队列满 + allowed, err = svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err) + require.False(t, allowed, "账号等待队列满时应返回 false") +} + +// TestWaitingQueue_FailOpen_OnCacheError 测试 Redis 故障时 fail-open +func TestWaitingQueue_FailOpen_OnCacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis connection refused")} + svc := NewConcurrencyService(cache) + + // 用户级:Redis 错误时允许通过 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err, "Redis 错误不应向调用方传播") + require.True(t, allowed, "Redis 故障时应 fail-open 放行") + + // 账号级:同样 fail-open + allowed, err = svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err, "Redis 错误不应向调用方传播") + require.True(t, allowed, "Redis 故障时应 fail-open 放行") +} + +// TestCalculateMaxWait_Scenarios 测试最大等待队列大小计算 +func TestCalculateMaxWait_Scenarios(t *testing.T) { + tests := []struct { + concurrency int + expected int + }{ + {5, 25}, // 5 + 20 + {10, 30}, // 10 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // min(1) + 20 + {-10, 21}, // min(1) + 20 + {100, 120}, // 100 + 20 + } + for _, tt := range tests { + result := CalculateMaxWait(tt.concurrency) + require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency) + } +} diff --git a/backend/internal/service/gemini_error_policy_test.go b/backend/internal/service/gemini_error_policy_test.go new file mode 100644 index 00000000..2ce8793a --- /dev/null +++ b/backend/internal/service/gemini_error_policy_test.go @@ -0,0 +1,384 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// TestShouldFailoverGeminiUpstreamError — verifies the failover decision +// for the ErrorPolicyNone path (original logic preserved). +// --------------------------------------------------------------------------- + +func TestShouldFailoverGeminiUpstreamError(t *testing.T) { + svc := &GeminiMessagesCompatService{} + + tests := []struct { + name string + statusCode int + expected bool + }{ + {"401_failover", 401, true}, + {"403_failover", 403, true}, + {"429_failover", 429, true}, + {"529_failover", 529, true}, + {"500_failover", 500, true}, + {"502_failover", 502, true}, + {"503_failover", 503, true}, + {"400_no_failover", 400, false}, + {"404_no_failover", 404, false}, + {"422_no_failover", 422, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode) + require.Equal(t, tt.expected, got) + }) + } +} + +// --------------------------------------------------------------------------- +// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works +// correctly for Gemini platform accounts (API Key type). +// --------------------------------------------------------------------------- + +func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expected ErrorPolicyResult + }{ + { + name: "gemini_apikey_custom_codes_hit", + account: &Account{ + ID: 100, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 429, + body: []byte(`{"error":"rate limited"}`), + expected: ErrorPolicyMatched, + }, + { + name: "gemini_apikey_custom_codes_miss", + account: &Account{ + ID: 101, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, + body: []byte(`{"error":"internal"}`), + expected: ErrorPolicySkipped, + }, + { + name: "gemini_apikey_no_custom_codes_returns_none", + account: &Account{ + ID: 102, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + statusCode: 500, + body: []byte(`{"error":"internal"}`), + expected: ErrorPolicyNone, + }, + { + name: "gemini_apikey_temp_unschedulable_hit", + account: &Account{ + ID: 103, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded service`), + expected: ErrorPolicyTempUnscheduled, + }, + { + name: "gemini_custom_codes_override_temp_unschedulable", + account: &Account{ + ID: 104, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(503)}, + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expected: ErrorPolicyMatched, // custom codes take precedence + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body) + require.Equal(t, tt.expected, result) + }) + } +} + +// --------------------------------------------------------------------------- +// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling +// paths produce the correct behavior for each ErrorPolicyResult. +// +// These tests simulate the inline error policy switch in handleClaudeCompat +// and forwardNativeGemini by calling the same methods in the same order. +// --------------------------------------------------------------------------- + +func TestGeminiErrorPolicyIntegration(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + account *Account + statusCode int + respBody []byte + expectFailover bool // expect UpstreamFailoverError + expectHandleError bool // expect handleGeminiUpstreamError to be called + expectShouldFailover bool // for None path, whether shouldFailover triggers + }{ + { + name: "custom_codes_matched_429_failover", + account: &Account{ + ID: 200, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 429, + respBody: []byte(`{"error":"rate limited"}`), + expectFailover: true, + expectHandleError: true, + }, + { + name: "custom_codes_skipped_500_no_failover", + account: &Account{ + ID: 201, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, + respBody: []byte(`{"error":"internal"}`), + expectFailover: false, + expectHandleError: false, + }, + { + name: "temp_unschedulable_matched_failover", + account: &Account{ + ID: 202, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + respBody: []byte(`overloaded`), + expectFailover: true, + expectHandleError: true, + }, + { + name: "no_policy_429_failover_via_shouldFailover", + account: &Account{ + ID: 203, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + statusCode: 429, + respBody: []byte(`{"error":"rate limited"}`), + expectFailover: true, + expectHandleError: true, + expectShouldFailover: true, + }, + { + name: "no_policy_400_no_failover", + account: &Account{ + ID: 204, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + statusCode: 400, + respBody: []byte(`{"error":"bad request"}`), + expectFailover: false, + expectHandleError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &geminiErrorPolicyRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + rateLimitService: rlSvc, + } + + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // Simulate the Claude compat error handling path (same logic as native). + // This mirrors the inline switch in handleClaudeCompat. + var handleErrorCalled bool + var gotFailover bool + + ctx := context.Background() + statusCode := tt.statusCode + respBody := tt.respBody + account := tt.account + headers := http.Header{} + + if svc.rateLimitService != nil { + switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) { + case ErrorPolicySkipped: + // Skipped → return error directly (no handleGeminiUpstreamError, no failover) + gotFailover = false + handleErrorCalled = false + goto verify + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody) + handleErrorCalled = true + gotFailover = true + goto verify + } + } + + // ErrorPolicyNone → original logic + svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody) + handleErrorCalled = true + if svc.shouldFailoverGeminiUpstreamError(statusCode) { + gotFailover = true + } + + verify: + require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch") + require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch") + + if tt.expectShouldFailover { + require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode), + "shouldFailoverGeminiUpstreamError should return true for status %d", statusCode) + } + }) + } +} + +// --------------------------------------------------------------------------- +// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety +// --------------------------------------------------------------------------- + +func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) { + svc := &GeminiMessagesCompatService{ + rateLimitService: nil, + } + + // When rateLimitService is nil, error policy is skipped → falls through to + // shouldFailoverGeminiUpstreamError (original logic). + // Verify this doesn't panic and follows expected behavior. + + ctx := context.Background() + account := &Account{ + ID: 300, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + } + + // The nil check should prevent CheckErrorPolicy from being called + if svc.rateLimitService != nil { + t.Fatal("rateLimitService should be nil for this test") + } + + // shouldFailoverGeminiUpstreamError still works + require.True(t, svc.shouldFailoverGeminiUpstreamError(429)) + require.False(t, svc.shouldFailoverGeminiUpstreamError(400)) + + // handleGeminiUpstreamError should not panic with nil rateLimitService + require.NotPanics(t, func() { + svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`)) + }) +} + +// --------------------------------------------------------------------------- +// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error +// policy tests. Embeds mockAccountRepoForGemini and adds tracking. +// --------------------------------------------------------------------------- + +type geminiErrorPolicyRepo struct { + mockAccountRepoForGemini + setErrorCalls int + setRateLimitedCalls int + setTempCalls int +} + +func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error { + r.setErrorCalls++ + return nil +} + +func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error { + r.setRateLimitedCalls++ + return nil +} + +func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.setTempCalls++ + return nil +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 0f156c2e..1c38b6c2 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -22,10 +22,12 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" ) const geminiStickySessionTTL = time.Hour @@ -51,6 +53,7 @@ type GeminiMessagesCompatService struct { httpUpstream HTTPUpstream antigravityGatewayService *AntigravityGatewayService cfg *config.Config + responseHeaderFilter *responseheaders.CompiledHeaderFilter } func NewGeminiMessagesCompatService( @@ -74,6 +77,7 @@ func NewGeminiMessagesCompatService( httpUpstream: httpUpstream, antigravityGatewayService: antigravityGatewayService, cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), } } @@ -227,6 +231,16 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest( account *Account, requestedModel, platform string, useMixedScheduling bool, +) bool { + return s.isAccountUsableForRequestWithPrecheck(ctx, account, requestedModel, platform, useMixedScheduling, nil) +} + +func (s *GeminiMessagesCompatService) isAccountUsableForRequestWithPrecheck( + ctx context.Context, + account *Account, + requestedModel, platform string, + useMixedScheduling bool, + precheckResult map[int64]bool, ) bool { // 检查模型调度能力 // Check model scheduling capability @@ -248,7 +262,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest( // 速率限制预检 // Rate limit precheck - if !s.passesRateLimitPreCheck(ctx, account, requestedModel) { + if !s.passesRateLimitPreCheckWithCache(ctx, account, requestedModel, precheckResult) { return false } @@ -270,18 +284,20 @@ func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account return false } -// passesRateLimitPreCheck 执行速率限制预检。 -// 返回 true 表示通过预检或无需预检。 -// -// passesRateLimitPreCheck performs rate limit precheck. -// Returns true if passed or precheck not required. -func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool { +func (s *GeminiMessagesCompatService) passesRateLimitPreCheckWithCache(ctx context.Context, account *Account, requestedModel string, precheckResult map[int64]bool) bool { if s.rateLimitService == nil || requestedModel == "" { return true } + + if precheckResult != nil { + if ok, exists := precheckResult[account.ID]; exists { + return ok + } + } + ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel) if err != nil { - log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) } return ok } @@ -300,6 +316,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( useMixedScheduling bool, ) *Account { var selected *Account + precheckResult := s.buildPreCheckUsageResultMap(ctx, accounts, requestedModel) for i := range accounts { acc := &accounts[i] @@ -310,7 +327,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( } // 检查账号是否可用于当前请求 - if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) { + if !s.isAccountUsableForRequestWithPrecheck(ctx, acc, requestedModel, platform, useMixedScheduling, precheckResult) { continue } @@ -328,6 +345,23 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( return selected } +func (s *GeminiMessagesCompatService) buildPreCheckUsageResultMap(ctx context.Context, accounts []Account, requestedModel string) map[int64]bool { + if s.rateLimitService == nil || requestedModel == "" || len(accounts) == 0 { + return nil + } + + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + candidates = append(candidates, &accounts[i]) + } + + result, err := s.rateLimitService.PreCheckUsageBatch(ctx, candidates, requestedModel) + if err != nil { + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheckBatch] failed: %v", err) + } + return result +} + // isBetterGeminiAccount 判断 candidate 是否比 current 更优。 // 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。 // @@ -560,10 +594,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -640,10 +671,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -703,7 +731,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Message: safeErr, }) if attempt < geminiMaxRetries { - log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) sleepGeminiBackoff(attempt) continue } @@ -759,7 +787,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex } retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody) if txErr == nil { - log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName) geminiReq = retryGeminiReq // Consume one retry budget attempt and continue with the updated request payload. sleepGeminiBackoff(1) @@ -776,6 +804,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex break } + // 错误策略优先:匹配则跳过重试直接处理。 + if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched { + resp = rebuilt + break + } else { + resp = rebuilt + } + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() @@ -818,7 +854,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Detail: upstreamDetail, }) - log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) sleepGeminiBackoff(attempt) continue } @@ -837,37 +873,77 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - tempMatched := false + // 统一错误策略:自定义错误码 + 临时不可调度 if s.rateLimitService != nil { - tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody) - } - s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - if tempMatched { - upstreamReqID := resp.Header.Get(requestIDHeader) - if upstreamReqID == "" { - upstreamReqID = resp.Header.Get("x-goog-request-id") - } - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 + switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) { + case ErrorPolicySkipped: + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") } - upstreamDetail = truncateString(string(respBody), maxBytes) + return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody) + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + } + + // ErrorPolicyNone → 原有逻辑 + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + // 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁 + if resp.StatusCode == http.StatusBadRequest { + msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if isGoogleProjectConfigError(msg400) { + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true} } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: upstreamReqID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { upstreamReqID := resp.Header.Get(requestIDHeader) @@ -926,7 +1002,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if err != nil { return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream") } - claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel) + collectedBytes, _ := json.Marshal(collected) + claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel, collectedBytes) c.JSON(http.StatusOK, claudeResp) usage = usageObj2 if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) { @@ -1026,10 +1103,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1097,10 +1171,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1159,7 +1230,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. Message: safeErr, }) if attempt < geminiMaxRetries { - log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) sleepGeminiBackoff(attempt) continue } @@ -1179,6 +1250,14 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr) } + // 错误策略优先:匹配则跳过重试直接处理。 + if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched { + resp = rebuilt + break + } else { + resp = rebuilt + } + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() @@ -1220,7 +1299,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. Detail: upstreamDetail, }) - log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) sleepGeminiBackoff(attempt) continue } @@ -1261,14 +1340,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - tempMatched := false - if s.rateLimitService != nil { - tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody) - } - s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - // Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens. // This avoids Gemini SDKs failing hard during preflight token counting. + // Checked before error policy so it always works regardless of custom error codes. if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) { estimated := estimateGeminiCountTokens(body) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) @@ -1282,29 +1356,73 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. }, nil } - if tempMatched { - evBody := unwrapIfNeeded(isOAuth, respBody) - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 + // 统一错误策略:自定义错误码 + 临时不可调度 + if s.rateLimitService != nil { + switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) { + case ErrorPolicySkipped: + respBody = unwrapIfNeeded(isOAuth, respBody) + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" } - upstreamDetail = truncateString(string(evBody), maxBytes) + c.Data(http.StatusInternalServerError, contentType, respBody) + return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode) + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + evBody := unwrapIfNeeded(isOAuth, respBody) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(evBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + } + + // ErrorPolicyNone → 原有逻辑 + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + // 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁 + if resp.StatusCode == http.StatusBadRequest { + msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if isGoogleProjectConfigError(msg400) { + evBody := unwrapIfNeeded(isOAuth, respBody) + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody))) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(evBody), maxBytes) + } + log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody, RetryableOnSameAccount: true} } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { evBody := unwrapIfNeeded(isOAuth, respBody) @@ -1341,7 +1459,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. maxBytes = 2048 } upstreamDetail = truncateString(string(respBody), maxBytes) - log.Printf("[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ @@ -1417,6 +1535,26 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. }, nil } +// checkErrorPolicyInLoop 在重试循环内预检查错误策略。 +// 返回 true 表示策略已匹配(调用者应 break),resp 已重建可直接使用。 +// 返回 false 表示 ErrorPolicyNone,resp 已重建,调用者继续走重试逻辑。 +func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop( + ctx context.Context, account *Account, resp *http.Response, +) (matched bool, rebuilt *http.Response) { + if resp.StatusCode < 400 || s.rateLimitService == nil { + return false, resp + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + rebuilt = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(body)), + } + policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body) + return policy != ErrorPolicyNone, rebuilt +} + func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool { switch statusCode { case 429, 500, 502, 503, 504, 529: @@ -1498,7 +1636,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc }) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) } if status, errType, errMsg, matched := applyErrorPassthroughRule( @@ -1718,12 +1856,17 @@ func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response") } - geminiResp, err := unwrapGeminiResponse(body) + unwrappedBody, err := unwrapGeminiResponse(body) if err != nil { return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") } - claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel) + var geminiResp map[string]any + if err := json.Unmarshal(unwrappedBody, &geminiResp); err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, unwrappedBody) c.JSON(http.StatusOK, claudeResp) return usage, nil @@ -1796,11 +1939,16 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re continue } - geminiResp, err := unwrapGeminiResponse([]byte(payload)) + unwrappedBytes, err := unwrapGeminiResponse([]byte(payload)) if err != nil { continue } + var geminiResp map[string]any + if err := json.Unmarshal(unwrappedBytes, &geminiResp); err != nil { + continue + } + if fr := extractGeminiFinishReason(geminiResp); fr != "" { finishReason = fr } @@ -1927,7 +2075,7 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re } } - if u := extractGeminiUsage(geminiResp); u != nil { + if u := extractGeminiUsage(unwrappedBytes); u != nil { usage = *u } @@ -2018,11 +2166,7 @@ func unwrapIfNeeded(isOAuth bool, raw []byte) []byte { if err != nil { return raw } - b, err := json.Marshal(inner) - if err != nil { - return raw - } - return b + return inner } func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) { @@ -2046,17 +2190,20 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag } default: var parsed map[string]any + var rawBytes []byte if isOAuth { - inner, err := unwrapGeminiResponse([]byte(payload)) - if err == nil && inner != nil { - parsed = inner + innerBytes, err := unwrapGeminiResponse([]byte(payload)) + if err == nil { + rawBytes = innerBytes + _ = json.Unmarshal(innerBytes, &parsed) } } else { - _ = json.Unmarshal([]byte(payload), &parsed) + rawBytes = []byte(payload) + _ = json.Unmarshal(rawBytes, &parsed) } if parsed != nil { last = parsed - if u := extractGeminiUsage(parsed); u != nil { + if u := extractGeminiUsage(rawBytes); u != nil { usage = u } if parts := extractGeminiParts(parsed); len(parts) > 0 { @@ -2185,53 +2332,27 @@ func isGeminiInsufficientScope(headers http.Header, body []byte) bool { } func estimateGeminiCountTokens(reqBody []byte) int { - var obj map[string]any - if err := json.Unmarshal(reqBody, &obj); err != nil { - return 0 - } - - var texts []string + total := 0 // systemInstruction.parts[].text - if si, ok := obj["systemInstruction"].(map[string]any); ok { - if parts, ok := si["parts"].([]any); ok { - for _, p := range parts { - if pm, ok := p.(map[string]any); ok { - if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" { - texts = append(texts, t) - } - } - } + gjson.GetBytes(reqBody, "systemInstruction.parts").ForEach(func(_, part gjson.Result) bool { + if t := strings.TrimSpace(part.Get("text").String()); t != "" { + total += estimateTokensForText(t) } - } + return true + }) // contents[].parts[].text - if contents, ok := obj["contents"].([]any); ok { - for _, c := range contents { - cm, ok := c.(map[string]any) - if !ok { - continue + gjson.GetBytes(reqBody, "contents").ForEach(func(_, content gjson.Result) bool { + content.Get("parts").ForEach(func(_, part gjson.Result) bool { + if t := strings.TrimSpace(part.Get("text").String()); t != "" { + total += estimateTokensForText(t) } - parts, ok := cm["parts"].([]any) - if !ok { - continue - } - for _, p := range parts { - pm, ok := p.(map[string]any) - if !ok { - continue - } - if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" { - texts = append(texts, t) - } - } - } - } + return true + }) + return true + }) - total := 0 - for _, t := range texts { - total += estimateTokensForText(t) - } if total < 0 { return 0 } @@ -2269,31 +2390,39 @@ type UpstreamHTTPResult struct { } func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) { - // Log response headers for debugging - log.Printf("[GeminiAPI] ========== Response Headers ==========") - for key, values := range resp.Header { - if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { - log.Printf("[GeminiAPI] %s: %v", key, values) + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================") } - log.Printf("[GeminiAPI] ========================================") - respBody, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } - var parsed map[string]any if isOAuth { - parsed, err = unwrapGeminiResponse(respBody) - if err == nil && parsed != nil { - respBody, _ = json.Marshal(parsed) + unwrappedBody, uwErr := unwrapGeminiResponse(respBody) + if uwErr == nil { + respBody = unwrappedBody } - } else { - _ = json.Unmarshal(respBody, &parsed) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := resp.Header.Get("Content-Type") if contentType == "" { @@ -2301,26 +2430,25 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co } c.Data(resp.StatusCode, contentType, respBody) - if parsed != nil { - if u := extractGeminiUsage(parsed); u != nil { - return u, nil - } + if u := extractGeminiUsage(respBody); u != nil { + return u, nil } return &ClaudeUsage{}, nil } func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) { - // Log response headers for debugging - log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========") - for key, values := range resp.Header { - if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { - log.Printf("[GeminiAPI] %s: %v", key, values) + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================") } - log.Printf("[GeminiAPI] ====================================================") - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } c.Status(resp.StatusCode) @@ -2357,23 +2485,19 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte var rawToWrite string rawToWrite = payload - var parsed map[string]any + var rawBytes []byte if isOAuth { - inner, err := unwrapGeminiResponse([]byte(payload)) - if err == nil && inner != nil { - parsed = inner - if b, err := json.Marshal(inner); err == nil { - rawToWrite = string(b) - } + innerBytes, err := unwrapGeminiResponse([]byte(payload)) + if err == nil { + rawToWrite = string(innerBytes) + rawBytes = innerBytes } } else { - _ = json.Unmarshal([]byte(payload), &parsed) + rawBytes = []byte(payload) } - if parsed != nil { - if u := extractGeminiUsage(parsed); u != nil { - usage = u - } + if u := extractGeminiUsage(rawBytes); u != nil { + usage = u } if firstTokenMs == nil { @@ -2420,10 +2544,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac return nil, errors.New("invalid path") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, err @@ -2468,7 +2589,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) wwwAuthenticate := resp.Header.Get("Www-Authenticate") - filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders) + filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.responseHeaderFilter) if wwwAuthenticate != "" { filteredHeaders.Set("Www-Authenticate", wwwAuthenticate) } @@ -2479,19 +2600,18 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac }, nil } -func unwrapGeminiResponse(raw []byte) (map[string]any, error) { - var outer map[string]any - if err := json.Unmarshal(raw, &outer); err != nil { - return nil, err +// unwrapGeminiResponse 解包 Gemini OAuth 响应中的 response 字段 +// 使用 gjson 零拷贝提取,避免完整 Unmarshal+Marshal +func unwrapGeminiResponse(raw []byte) ([]byte, error) { + result := gjson.GetBytes(raw, "response") + if result.Exists() && result.Type == gjson.JSON { + return []byte(result.Raw), nil } - if resp, ok := outer["response"].(map[string]any); ok && resp != nil { - return resp, nil - } - return outer, nil + return raw, nil } -func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string) (map[string]any, *ClaudeUsage) { - usage := extractGeminiUsage(geminiResp) +func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string, rawData []byte) (map[string]any, *ClaudeUsage) { + usage := extractGeminiUsage(rawData) if usage == nil { usage = &ClaudeUsage{} } @@ -2555,19 +2675,20 @@ func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel strin return resp, usage } -func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage { - usageMeta, ok := geminiResp["usageMetadata"].(map[string]any) - if !ok || usageMeta == nil { +func extractGeminiUsage(data []byte) *ClaudeUsage { + usage := gjson.GetBytes(data, "usageMetadata") + if !usage.Exists() { return nil } - prompt, _ := asInt(usageMeta["promptTokenCount"]) - cand, _ := asInt(usageMeta["candidatesTokenCount"]) - cached, _ := asInt(usageMeta["cachedContentTokenCount"]) + prompt := int(usage.Get("promptTokenCount").Int()) + cand := int(usage.Get("candidatesTokenCount").Int()) + cached := int(usage.Get("cachedContentTokenCount").Int()) + thoughts := int(usage.Get("thoughtsTokenCount").Int()) // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 return &ClaudeUsage{ InputTokens: prompt - cached, - OutputTokens: cand, + OutputTokens: cand + thoughts, CacheReadInputTokens: cached, } } @@ -2592,6 +2713,10 @@ func asInt(v any) (int, bool) { } func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) { + // 遵守自定义错误码策略:未命中则跳过所有限流处理 + if !account.ShouldHandleErrorCode(statusCode) { + return + } if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) { s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) return @@ -2616,16 +2741,16 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont cooldown = s.rateLimitService.GeminiCooldown(ctx, account) } ra = time.Now().Add(cooldown) - log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second)) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second)) } else { // API Key / AI Studio OAuth: PST 午夜 if ts := nextGeminiDailyResetUnix(); ts != nil { ra = time.Unix(*ts, 0) - log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra) } else { // 兜底:5 分钟 ra = time.Now().Add(5 * time.Minute) - log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited, fallback to 5min", account.ID) } } _ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) @@ -2635,45 +2760,41 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont // 使用解析到的重置时间 resetTime := time.Unix(*resetAt, 0) _ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime) - log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)", + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)", account.ID, resetTime, oauthType, tierID) } // ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳 func ParseGeminiRateLimitResetTime(body []byte) *int64 { - // Try to parse metadata.quotaResetDelay like "12.345s" - var parsed map[string]any - if err := json.Unmarshal(body, &parsed); err == nil { - if errObj, ok := parsed["error"].(map[string]any); ok { - if msg, ok := errObj["message"].(string); ok { - if looksLikeGeminiDailyQuota(msg) { - if ts := nextGeminiDailyResetUnix(); ts != nil { - return ts - } - } - } - if details, ok := errObj["details"].([]any); ok { - for _, d := range details { - dm, ok := d.(map[string]any) - if !ok { - continue - } - if meta, ok := dm["metadata"].(map[string]any); ok { - if v, ok := meta["quotaResetDelay"].(string); ok { - if dur, err := time.ParseDuration(v); err == nil { - // Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s), - // which can affect scheduling decisions around thresholds (like 10s). - ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds())) - return &ts - } - } - } - } - } + // 第一阶段:gjson 结构化提取 + errMsg := gjson.GetBytes(body, "error.message").String() + if looksLikeGeminiDailyQuota(errMsg) { + if ts := nextGeminiDailyResetUnix(); ts != nil { + return ts } } - // Match "Please retry in Xs" + // 遍历 error.details 查找 quotaResetDelay + var found *int64 + gjson.GetBytes(body, "error.details").ForEach(func(_, detail gjson.Result) bool { + v := detail.Get("metadata.quotaResetDelay").String() + if v == "" { + return true + } + if dur, err := time.ParseDuration(v); err == nil { + // Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s), + // which can affect scheduling decisions around thresholds (like 10s). + ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds())) + found = &ts + return false + } + return true + }) + if found != nil { + return found + } + + // 第二阶段:regex 回退匹配 "Please retry in Xs" matches := retryInRegex.FindStringSubmatch(string(body)) if len(matches) == 2 { if dur, err := time.ParseDuration(matches[1] + "s"); err == nil { diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index f31b40ec..7560f480 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -2,8 +2,17 @@ package service import ( "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" "strings" "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" ) // TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 @@ -129,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { } } +func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + svc := &GeminiMessagesCompatService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + GeminiDebugResponseHeaders: false, + }, + }, + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-RateLimit-Limit": []string{"60"}, + }, + Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)), + } + + usage, err := svc.handleNativeNonStreamingResponse(c, resp, false) + require.NoError(t, err) + require.NotNil(t, usage) + require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志") +} + func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { claudeReq := map[string]any{ "model": "claude-haiku-4-5-20251001", @@ -203,3 +244,324 @@ func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s) } } + +// TestUnwrapGeminiResponse 测试 unwrapGeminiResponse 的各种输入场景 +// 关键区别:只有 response 为 JSON 对象/数组时才解包 +func TestUnwrapGeminiResponse(t *testing.T) { + // 构造 >50KB 的大型 JSON 对象 + largePadding := strings.Repeat("x", 50*1024) + largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding)) + largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding) + + tests := []struct { + name string + input []byte + expected string + wantErr bool + }{ + { + name: "正常 response 包装(JSON 对象)", + input: []byte(`{"response":{"key":"val"}}`), + expected: `{"key":"val"}`, + }, + { + name: "无包装直接返回", + input: []byte(`{"key":"val"}`), + expected: `{"key":"val"}`, + }, + { + name: "空 JSON", + input: []byte(`{}`), + expected: `{}`, + }, + { + name: "null response 返回原始 body", + input: []byte(`{"response":null}`), + expected: `{"response":null}`, + }, + { + name: "非法 JSON 返回原始 body", + input: []byte(`not json`), + expected: `not json`, + }, + { + name: "response 为基础类型 string 返回原始 body", + input: []byte(`{"response":"hello"}`), + expected: `{"response":"hello"}`, + }, + { + name: "嵌套 response 只解一层", + input: []byte(`{"response":{"response":{"inner":true}}}`), + expected: `{"response":{"inner":true}}`, + }, + { + name: "大型 JSON >50KB", + input: largeInput, + expected: largeExpected, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := unwrapGeminiResponse(tt.input) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.expected, strings.TrimSpace(string(got))) + }) + } +} + +// --------------------------------------------------------------------------- +// Task 8.1 — extractGeminiUsage 测试 +// --------------------------------------------------------------------------- + +func TestExtractGeminiUsage(t *testing.T) { + tests := []struct { + name string + input string + wantNil bool + wantUsage *ClaudeUsage + }{ + { + name: "完整 usageMetadata", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50,"cachedContentTokenCount":20}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 80, + OutputTokens: 50, + CacheReadInputTokens: 20, + }, + }, + { + name: "包含 thoughtsTokenCount", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 100, + OutputTokens: 70, + CacheReadInputTokens: 0, + }, + }, + { + name: "包含 thoughtsTokenCount 与缓存", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"cachedContentTokenCount":30,"thoughtsTokenCount":50}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 70, + OutputTokens: 70, + CacheReadInputTokens: 30, + }, + }, + { + name: "缺失 cachedContentTokenCount", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 100, + OutputTokens: 50, + CacheReadInputTokens: 0, + }, + }, + { + name: "无 usageMetadata", + input: `{"candidates":[]}`, + wantNil: true, + }, + { + // gjson 对 null 返回 Exists()=true,因此函数不会返回 nil, + // 而是返回全零的 ClaudeUsage。 + name: "null usageMetadata — gjson Exists 为 true", + input: `{"usageMetadata":null}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 0, + OutputTokens: 0, + CacheReadInputTokens: 0, + }, + }, + { + name: "零值字段", + input: `{"usageMetadata":{"promptTokenCount":0,"candidatesTokenCount":0,"cachedContentTokenCount":0}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 0, + OutputTokens: 0, + CacheReadInputTokens: 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractGeminiUsage([]byte(tt.input)) + if tt.wantNil { + if got != nil { + t.Fatalf("期望返回 nil,实际返回 %+v", got) + } + return + } + if got == nil { + t.Fatalf("期望返回非 nil,实际返回 nil") + } + if got.InputTokens != tt.wantUsage.InputTokens { + t.Errorf("InputTokens: 期望 %d,实际 %d", tt.wantUsage.InputTokens, got.InputTokens) + } + if got.OutputTokens != tt.wantUsage.OutputTokens { + t.Errorf("OutputTokens: 期望 %d,实际 %d", tt.wantUsage.OutputTokens, got.OutputTokens) + } + if got.CacheReadInputTokens != tt.wantUsage.CacheReadInputTokens { + t.Errorf("CacheReadInputTokens: 期望 %d,实际 %d", tt.wantUsage.CacheReadInputTokens, got.CacheReadInputTokens) + } + }) + } +} + +// --------------------------------------------------------------------------- +// Task 8.2 — estimateGeminiCountTokens 测试 +// --------------------------------------------------------------------------- + +func TestEstimateGeminiCountTokens(t *testing.T) { + tests := []struct { + name string + input string + wantGt0 bool // 期望结果 > 0 + wantExact *int // 如果非 nil,期望精确匹配 + }{ + { + name: "含 systemInstruction 和 contents", + input: `{ + "systemInstruction":{"parts":[{"text":"You are a helpful assistant."}]}, + "contents":[{"parts":[{"text":"Hello, how are you?"}]}] + }`, + wantGt0: true, + }, + { + name: "仅 contents,无 systemInstruction", + input: `{ + "contents":[{"parts":[{"text":"Hello, how are you?"}]}] + }`, + wantGt0: true, + }, + { + name: "空 parts", + input: `{"contents":[{"parts":[]}]}`, + wantGt0: false, + wantExact: intPtr(0), + }, + { + name: "非文本 parts(inlineData)", + input: `{"contents":[{"parts":[{"inlineData":{"mimeType":"image/png"}}]}]}`, + wantGt0: false, + wantExact: intPtr(0), + }, + { + name: "空白文本", + input: `{"contents":[{"parts":[{"text":" "}]}]}`, + wantGt0: false, + wantExact: intPtr(0), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := estimateGeminiCountTokens([]byte(tt.input)) + if tt.wantExact != nil { + if got != *tt.wantExact { + t.Errorf("期望精确值 %d,实际 %d", *tt.wantExact, got) + } + return + } + if tt.wantGt0 && got <= 0 { + t.Errorf("期望返回 > 0,实际 %d", got) + } + if !tt.wantGt0 && got != 0 { + t.Errorf("期望返回 0,实际 %d", got) + } + }) + } +} + +// --------------------------------------------------------------------------- +// Task 8.3 — ParseGeminiRateLimitResetTime 测试 +// --------------------------------------------------------------------------- + +func TestParseGeminiRateLimitResetTime(t *testing.T) { + tests := []struct { + name string + input string + wantNil bool + approxDelta int64 // 预期的 (返回值 - now) 大约是多少秒 + }{ + { + name: "正常 quotaResetDelay", + input: `{"error":{"details":[{"metadata":{"quotaResetDelay":"12.345s"}}]}}`, + wantNil: false, + approxDelta: 13, // 向上取整 12.345 -> 13 + }, + { + name: "daily quota", + input: `{"error":{"message":"quota per day exceeded"}}`, + wantNil: false, + approxDelta: -1, // 不检查精确 delta,仅检查非 nil + }, + { + name: "无 details 且无 regex 匹配", + input: `{"error":{"message":"rate limit"}}`, + wantNil: true, + }, + { + name: "regex 回退匹配", + input: `Please retry in 30s`, + wantNil: false, + approxDelta: 30, + }, + { + name: "完全无匹配", + input: `{"error":{"code":429}}`, + wantNil: true, + }, + { + name: "非法 JSON 但 regex 回退仍工作", + input: `not json but Please retry in 10s`, + wantNil: false, + approxDelta: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + now := time.Now().Unix() + got := ParseGeminiRateLimitResetTime([]byte(tt.input)) + + if tt.wantNil { + if got != nil { + t.Fatalf("期望返回 nil,实际返回 %d", *got) + } + return + } + + if got == nil { + t.Fatalf("期望返回非 nil,实际返回 nil") + } + + // approxDelta == -1 表示只检查非 nil,不检查具体值(如 daily quota 场景) + if tt.approxDelta == -1 { + // 仅验证返回的时间戳在合理范围内(未来的某个时间) + if *got < now { + t.Errorf("期望返回的时间戳 >= now(%d),实际 %d", now, *got) + } + return + } + + // 使用 +/-2 秒容差进行范围检查 + delta := *got - now + if delta < tt.approxDelta-2 || delta > tt.approxDelta+2 { + t.Errorf("期望 delta 约为 %d 秒(+/-2),实际 delta = %d 秒(返回值=%d, now=%d)", + tt.approxDelta, delta, *got, now) + } + }) + } +} diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 601e7e2c..86bc9476 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -66,12 +66,20 @@ func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account) func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { return nil, nil } + +func (m *mockAccountRepoForGemini) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { + return nil, nil +} + +func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { + return nil, nil +} func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil } func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil } func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { @@ -133,9 +141,6 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } -func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - return nil -} func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { return nil } @@ -226,6 +231,10 @@ func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, gr return nil, nil } +func (m *mockGroupRepoForGemini) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { + return nil +} + var _ GroupRepository = (*mockGroupRepoForGemini)(nil) // mockGatewayCacheForGemini Gemini 测试用的 cache mock @@ -265,22 +274,6 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, return nil } -func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - return 0, nil -} - -func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { - return nil, nil -} - -func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { ctx := context.Background() diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index fd2932e6..08a74a37 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "log" "net/http" "regexp" "strconv" @@ -16,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) const ( @@ -54,6 +54,7 @@ type GeminiOAuthService struct { proxyRepo ProxyRepository oauthClient GeminiOAuthClient codeAssist GeminiCliCodeAssistClient + driveClient geminicli.DriveClient cfg *config.Config } @@ -66,6 +67,7 @@ func NewGeminiOAuthService( proxyRepo ProxyRepository, oauthClient GeminiOAuthClient, codeAssist GeminiCliCodeAssistClient, + driveClient geminicli.DriveClient, cfg *config.Config, ) *GeminiOAuthService { return &GeminiOAuthService{ @@ -73,6 +75,7 @@ func NewGeminiOAuthService( proxyRepo: proxyRepo, oauthClient: oauthClient, codeAssist: codeAssist, + driveClient: driveClient, cfg: cfg, } } @@ -81,8 +84,7 @@ func (s *GeminiOAuthService) GetOAuthConfig() *GeminiOAuthCapabilities { // AI Studio OAuth is only enabled when the operator configures a custom OAuth client. clientID := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientID) clientSecret := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientSecret) - enabled := clientID != "" && clientSecret != "" && - (clientID != geminicli.GeminiCLIOAuthClientID || clientSecret != geminicli.GeminiCLIOAuthClientSecret) + enabled := clientID != "" && clientSecret != "" && clientID != geminicli.GeminiCLIOAuthClientID return &GeminiOAuthCapabilities{ AIStudioOAuthEnabled: enabled, @@ -151,8 +153,7 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 return nil, err } - isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID && - effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret + isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID // AI Studio OAuth requires a user-provided OAuth client (built-in Gemini CLI client is scope-restricted). if oauthType == "ai_studio" && isBuiltinClient { @@ -330,27 +331,27 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string // inferGoogleOneTier infers Google One tier from Drive storage limit func inferGoogleOneTier(storageBytes int64) string { - log.Printf("[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB)) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB)) if storageBytes <= 0 { - log.Printf("[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN") return GeminiTierGoogleOneUnknown } if storageBytes > StorageTierUnlimited { - log.Printf("[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited) return GeminiTierGoogleAIUltra } if storageBytes >= StorageTierAIPremium { - log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium) return GeminiTierGoogleAIPro } if storageBytes >= StorageTierFree { - log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree) return GeminiTierGoogleOneFree } - log.Printf("[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree) return GeminiTierGoogleOneUnknown } @@ -360,30 +361,29 @@ func inferGoogleOneTier(storageBytes int64) string { // 2. Personal accounts will get 403/404 from cloudaicompanion.googleapis.com // 3. Google consumer (Google One) and enterprise (GCP) systems are physically isolated func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) { - log.Printf("[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)") // Use Drive API to infer tier from storage quota (requires drive.readonly scope) - log.Printf("[GeminiOAuth] Calling Drive API for storage quota...") - driveClient := geminicli.NewDriveClient() + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Calling Drive API for storage quota...") - storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL) + storageInfo, err := s.driveClient.GetStorageQuota(ctx, accessToken, proxyURL) if err != nil { // Check if it's a 403 (scope not granted) if strings.Contains(err.Error(), "status 403") { - log.Printf("[GeminiOAuth] Drive API scope not available (403): %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive API scope not available (403): %v", err) return GeminiTierGoogleOneUnknown, nil, err } // Other errors - log.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Failed to fetch Drive storage: %v", err) return GeminiTierGoogleOneUnknown, nil, err } - log.Printf("[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)", + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)", storageInfo.Limit, float64(storageInfo.Limit)/float64(TB), storageInfo.Usage, float64(storageInfo.Usage)/float64(GB)) tierID := inferGoogleOneTier(storageInfo.Limit) - log.Printf("[GeminiOAuth] Inferred tier from storage: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Inferred tier from storage: %s", tierID) return tierID, storageInfo, nil } @@ -443,16 +443,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier( } func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { - log.Printf("[GeminiOAuth] ========== ExchangeCode START ==========") - log.Printf("[GeminiOAuth] SessionID: %s", input.SessionID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode START ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] SessionID: %s", input.SessionID) session, ok := s.sessionStore.Get(input.SessionID) if !ok { - log.Printf("[GeminiOAuth] ERROR: Session not found or expired") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Session not found or expired") return nil, fmt.Errorf("session not found or expired") } if strings.TrimSpace(input.State) == "" || input.State != session.State { - log.Printf("[GeminiOAuth] ERROR: Invalid state") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Invalid state") return nil, fmt.Errorf("invalid state") } @@ -463,7 +463,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch proxyURL = proxy.URL() } } - log.Printf("[GeminiOAuth] ProxyURL: %s", proxyURL) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ProxyURL: %s", proxyURL) redirectURI := session.RedirectURI @@ -472,8 +472,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if oauthType == "" { oauthType = "code_assist" } - log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType) - log.Printf("[GeminiOAuth] Project ID from session: %s", session.ProjectID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] OAuth Type: %s", oauthType) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Project ID from session: %s", session.ProjectID) // If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured. if oauthType == "ai_studio" { @@ -485,26 +485,25 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if err != nil { return nil, err } - isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID && - effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret + isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID if isBuiltinClient { return nil, fmt.Errorf("AI Studio OAuth requires a custom OAuth Client. Please use an AI Studio API Key account, or configure GEMINI_OAUTH_CLIENT_ID / GEMINI_OAUTH_CLIENT_SECRET and re-authorize") } } - // code_assist always uses the built-in client and its fixed redirect URI. - if oauthType == "code_assist" { + // code_assist/google_one always uses the built-in client and its fixed redirect URI. + if oauthType == "code_assist" || oauthType == "google_one" { redirectURI = geminicli.GeminiCLIRedirectURI } tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL) if err != nil { - log.Printf("[GeminiOAuth] ERROR: Failed to exchange code: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Failed to exchange code: %v", err) return nil, fmt.Errorf("failed to exchange code: %w", err) } - log.Printf("[GeminiOAuth] Token exchange successful") - log.Printf("[GeminiOAuth] Token scope: %s", tokenResp.Scope) - log.Printf("[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token exchange successful") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token scope: %s", tokenResp.Scope) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn) sessionProjectID := strings.TrimSpace(session.ProjectID) s.sessionStore.Delete(input.SessionID) @@ -526,40 +525,40 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch fallbackTierID = canonicalGeminiTierIDForOAuthType(oauthType, session.TierID) } - log.Printf("[GeminiOAuth] ========== Account Type Detection START ==========") - log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== Account Type Detection START ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] OAuth Type: %s", oauthType) // 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API // 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别 // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) switch oauthType { case "code_assist": - log.Printf("[GeminiOAuth] Processing code_assist OAuth type") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing code_assist OAuth type") if projectID == "" { - log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") var err error projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) - log.Printf("[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err) } else { - log.Printf("[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID) } } else { - log.Printf("[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID) // 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID _, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err) - log.Printf("[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err) } else { tierID = fetchedTierID - log.Printf("[GeminiOAuth] Successfully fetched tier_id: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched tier_id: %s", tierID) } } if strings.TrimSpace(projectID) == "" { - log.Printf("[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth") return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project") } // Prefer auto-detected tier; fall back to user-selected tier. @@ -567,31 +566,31 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if tierID == "" { if fallbackTierID != "" { tierID = fallbackTierID - log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) } else { tierID = GeminiTierGCPStandard - log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using default tier_id: %s", tierID) } } - log.Printf("[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID) case "google_one": - log.Printf("[GeminiOAuth] Processing google_one OAuth type") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing google_one OAuth type") // Google One accounts use cloudaicompanion API, which requires a project_id. // For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API. if projectID == "" { - log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") var err error projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { - log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err) return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err) } - log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched project_id: %s", projectID) } - log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Attempting to fetch Google One tier from Drive API...") // Attempt to fetch Drive storage tier var storageInfo *geminicli.DriveStorageInfo var err error @@ -599,12 +598,12 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if err != nil { // Log warning but don't block - use fallback fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err) - log.Printf("[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err) tierID = "" } else { - log.Printf("[GeminiOAuth] Successfully fetched Drive tier: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched Drive tier: %s", tierID) if storageInfo != nil { - log.Printf("[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)", + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)", storageInfo.Limit, float64(storageInfo.Limit)/float64(TB), storageInfo.Usage, float64(storageInfo.Usage)/float64(GB)) } @@ -613,10 +612,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if tierID == "" || tierID == GeminiTierGoogleOneUnknown { if fallbackTierID != "" { tierID = fallbackTierID - log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) } else { tierID = GeminiTierGoogleOneFree - log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using default tier_id: %s", tierID) } } fmt.Printf("[GeminiOAuth] Google One tierID after normalization: %s\n", tierID) @@ -639,7 +638,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch "drive_tier_updated_at": time.Now().Format(time.RFC3339), }, } - log.Printf("[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========") return tokenInfo, nil } @@ -652,10 +651,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch } default: - log.Printf("[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType) } - log.Printf("[GeminiOAuth] ========== Account Type Detection END ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== Account Type Detection END ==========") result := &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, @@ -668,8 +667,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch TierID: tierID, OAuthType: oauthType, } - log.Printf("[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID) - log.Printf("[GeminiOAuth] ========== ExchangeCode END ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode END ==========") return result, nil } @@ -952,23 +951,23 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr registeredTierID := strings.TrimSpace(loadResp.GetTier()) if registeredTierID != "" { // 已注册但未返回 cloudaicompanionProject,这在 Google One 用户中较常见:需要用户自行提供 project_id。 - log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID) // Try to get project from Cloud Resource Manager fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback) return strings.TrimSpace(fallback), tierID, nil } // No project found - user must provide project_id manually - log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually") return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID) } } // 未检测到 currentTier/paidTier,视为新用户,继续调用 onboardUser - log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID) req := &geminicli.OnboardUserRequest{ TierID: tierID, @@ -1046,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR ValidateResolvedIP: true, }) if err != nil { - client = &http.Client{Timeout: 30 * time.Second} + return "", fmt.Errorf("create http client failed: %w", err) } resp, err := client.Do(req) diff --git a/backend/internal/service/gemini_oauth_service_test.go b/backend/internal/service/gemini_oauth_service_test.go index 5591eb39..397b581d 100644 --- a/backend/internal/service/gemini_oauth_service_test.go +++ b/backend/internal/service/gemini_oauth_service_test.go @@ -1,17 +1,29 @@ +//go:build unit + package service import ( "context" + "fmt" "net/url" "strings" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) +// ===================== +// 保留原有测试 +// ===================== + func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { - t.Parallel() + // NOTE: This test sets process env; it must not run in parallel. + // The built-in Gemini CLI client secret is not embedded in this repository. + // Tests set a dummy secret via env to simulate operator-provided configuration. + t.Setenv(geminicli.GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") type testCase struct { name string @@ -89,7 +101,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg) + svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg) got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "") if tt.wantErrSubstr != "" { if err == nil { @@ -128,3 +140,1336 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { }) } } + +// ===================== +// 新增测试:validateTierID +// ===================== + +func TestValidateTierID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tierID string + wantErr bool + }{ + {name: "空字符串合法", tierID: "", wantErr: false}, + {name: "正常 tier_id", tierID: "google_one_free", wantErr: false}, + {name: "包含斜杠", tierID: "tier/sub", wantErr: false}, + {name: "包含连字符", tierID: "gcp-standard", wantErr: false}, + {name: "纯数字", tierID: "12345", wantErr: false}, + {name: "超长字符串(65个字符)", tierID: strings.Repeat("a", 65), wantErr: true}, + {name: "刚好64个字符", tierID: strings.Repeat("b", 64), wantErr: false}, + {name: "非法字符_空格", tierID: "tier id", wantErr: true}, + {name: "非法字符_中文", tierID: "tier_中文", wantErr: true}, + {name: "非法字符_特殊符号", tierID: "tier@id", wantErr: true}, + {name: "非法字符_感叹号", tierID: "tier!id", wantErr: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateTierID(tt.tierID) + if tt.wantErr && err == nil { + t.Fatalf("期望返回错误,但返回 nil") + } + if !tt.wantErr && err != nil { + t.Fatalf("不期望返回错误,但返回: %v", err) + } + }) + } +} + +// ===================== +// 新增测试:canonicalGeminiTierID +// ===================== + +func TestCanonicalGeminiTierID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw string + want string + }{ + // 空值 + {name: "空字符串", raw: "", want: ""}, + {name: "纯空白", raw: " ", want: ""}, + + // 已规范化的值(直接返回) + {name: "google_one_free", raw: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "google_ai_pro", raw: "google_ai_pro", want: GeminiTierGoogleAIPro}, + {name: "google_ai_ultra", raw: "google_ai_ultra", want: GeminiTierGoogleAIUltra}, + {name: "gcp_standard", raw: "gcp_standard", want: GeminiTierGCPStandard}, + {name: "gcp_enterprise", raw: "gcp_enterprise", want: GeminiTierGCPEnterprise}, + {name: "aistudio_free", raw: "aistudio_free", want: GeminiTierAIStudioFree}, + {name: "aistudio_paid", raw: "aistudio_paid", want: GeminiTierAIStudioPaid}, + {name: "google_one_unknown", raw: "google_one_unknown", want: GeminiTierGoogleOneUnknown}, + + // 大小写不敏感 + {name: "Google_One_Free 大写", raw: "Google_One_Free", want: GeminiTierGoogleOneFree}, + {name: "GCP_STANDARD 全大写", raw: "GCP_STANDARD", want: GeminiTierGCPStandard}, + + // legacy 映射: Google One + {name: "AI_PREMIUM -> google_ai_pro", raw: "AI_PREMIUM", want: GeminiTierGoogleAIPro}, + {name: "FREE -> google_one_free", raw: "FREE", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_BASIC -> google_one_free", raw: "GOOGLE_ONE_BASIC", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_STANDARD -> google_one_free", raw: "GOOGLE_ONE_STANDARD", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_UNLIMITED -> google_ai_ultra", raw: "GOOGLE_ONE_UNLIMITED", want: GeminiTierGoogleAIUltra}, + {name: "GOOGLE_ONE_UNKNOWN -> google_one_unknown", raw: "GOOGLE_ONE_UNKNOWN", want: GeminiTierGoogleOneUnknown}, + + // legacy 映射: Code Assist + {name: "STANDARD -> gcp_standard", raw: "STANDARD", want: GeminiTierGCPStandard}, + {name: "PRO -> gcp_standard", raw: "PRO", want: GeminiTierGCPStandard}, + {name: "LEGACY -> gcp_standard", raw: "LEGACY", want: GeminiTierGCPStandard}, + {name: "ENTERPRISE -> gcp_enterprise", raw: "ENTERPRISE", want: GeminiTierGCPEnterprise}, + {name: "ULTRA -> gcp_enterprise", raw: "ULTRA", want: GeminiTierGCPEnterprise}, + + // kebab-case + {name: "standard-tier -> gcp_standard", raw: "standard-tier", want: GeminiTierGCPStandard}, + {name: "pro-tier -> gcp_standard", raw: "pro-tier", want: GeminiTierGCPStandard}, + {name: "ultra-tier -> gcp_enterprise", raw: "ultra-tier", want: GeminiTierGCPEnterprise}, + + // 未知值 + {name: "unknown_value -> 空", raw: "unknown_value", want: ""}, + {name: "random-text -> 空", raw: "random-text", want: ""}, + + // 带空白 + {name: "带前后空白", raw: " google_one_free ", want: GeminiTierGoogleOneFree}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := canonicalGeminiTierID(tt.raw) + if got != tt.want { + t.Fatalf("canonicalGeminiTierID(%q) = %q, want %q", tt.raw, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:canonicalGeminiTierIDForOAuthType +// ===================== + +func TestCanonicalGeminiTierIDForOAuthType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + oauthType string + tierID string + want string + }{ + // google_one 类型过滤 + {name: "google_one + google_one_free", oauthType: "google_one", tierID: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "google_one + google_ai_pro", oauthType: "google_one", tierID: "google_ai_pro", want: GeminiTierGoogleAIPro}, + {name: "google_one + google_ai_ultra", oauthType: "google_one", tierID: "google_ai_ultra", want: GeminiTierGoogleAIUltra}, + {name: "google_one + gcp_standard 被过滤", oauthType: "google_one", tierID: "gcp_standard", want: ""}, + {name: "google_one + aistudio_free 被过滤", oauthType: "google_one", tierID: "aistudio_free", want: ""}, + {name: "google_one + AI_PREMIUM 遗留映射", oauthType: "google_one", tierID: "AI_PREMIUM", want: GeminiTierGoogleAIPro}, + + // code_assist 类型过滤 + {name: "code_assist + gcp_standard", oauthType: "code_assist", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + {name: "code_assist + gcp_enterprise", oauthType: "code_assist", tierID: "gcp_enterprise", want: GeminiTierGCPEnterprise}, + {name: "code_assist + google_one_free 被过滤", oauthType: "code_assist", tierID: "google_one_free", want: ""}, + {name: "code_assist + aistudio_free 被过滤", oauthType: "code_assist", tierID: "aistudio_free", want: ""}, + {name: "code_assist + STANDARD 遗留映射", oauthType: "code_assist", tierID: "STANDARD", want: GeminiTierGCPStandard}, + {name: "code_assist + standard-tier kebab", oauthType: "code_assist", tierID: "standard-tier", want: GeminiTierGCPStandard}, + + // ai_studio 类型过滤 + {name: "ai_studio + aistudio_free", oauthType: "ai_studio", tierID: "aistudio_free", want: GeminiTierAIStudioFree}, + {name: "ai_studio + aistudio_paid", oauthType: "ai_studio", tierID: "aistudio_paid", want: GeminiTierAIStudioPaid}, + {name: "ai_studio + gcp_standard 被过滤", oauthType: "ai_studio", tierID: "gcp_standard", want: ""}, + {name: "ai_studio + google_one_free 被过滤", oauthType: "ai_studio", tierID: "google_one_free", want: ""}, + + // 空值 + {name: "空 tierID", oauthType: "google_one", tierID: "", want: ""}, + {name: "空 oauthType + 有效 tierID", oauthType: "", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + {name: "未知 oauthType 接受规范化值", oauthType: "unknown_type", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + + // oauthType 大小写和空白 + {name: "GOOGLE_ONE 大写", oauthType: "GOOGLE_ONE", tierID: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "oauthType 带空白", oauthType: " code_assist ", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := canonicalGeminiTierIDForOAuthType(tt.oauthType, tt.tierID) + if got != tt.want { + t.Fatalf("canonicalGeminiTierIDForOAuthType(%q, %q) = %q, want %q", tt.oauthType, tt.tierID, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:extractTierIDFromAllowedTiers +// ===================== + +func TestExtractTierIDFromAllowedTiers(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + allowedTiers []geminicli.AllowedTier + want string + }{ + { + name: "nil 列表返回 LEGACY", + allowedTiers: nil, + want: "LEGACY", + }, + { + name: "空列表返回 LEGACY", + allowedTiers: []geminicli.AllowedTier{}, + want: "LEGACY", + }, + { + name: "有 IsDefault 的 tier", + allowedTiers: []geminicli.AllowedTier{ + {ID: "STANDARD", IsDefault: false}, + {ID: "PRO", IsDefault: true}, + {ID: "ENTERPRISE", IsDefault: false}, + }, + want: "PRO", + }, + { + name: "没有 IsDefault 取第一个非空", + allowedTiers: []geminicli.AllowedTier{ + {ID: "STANDARD", IsDefault: false}, + {ID: "ENTERPRISE", IsDefault: false}, + }, + want: "STANDARD", + }, + { + name: "IsDefault 的 ID 为空,取第一个非空", + allowedTiers: []geminicli.AllowedTier{ + {ID: "", IsDefault: true}, + {ID: "PRO", IsDefault: false}, + }, + want: "PRO", + }, + { + name: "所有 ID 都为空返回 LEGACY", + allowedTiers: []geminicli.AllowedTier{ + {ID: "", IsDefault: false}, + {ID: " ", IsDefault: false}, + }, + want: "LEGACY", + }, + { + name: "ID 带空白会被 trim", + allowedTiers: []geminicli.AllowedTier{ + {ID: " STANDARD ", IsDefault: true}, + }, + want: "STANDARD", + }, + { + name: "单个 tier 且 IsDefault", + allowedTiers: []geminicli.AllowedTier{ + {ID: "ENTERPRISE", IsDefault: true}, + }, + want: "ENTERPRISE", + }, + { + name: "单个 tier 非 IsDefault", + allowedTiers: []geminicli.AllowedTier{ + {ID: "ENTERPRISE", IsDefault: false}, + }, + want: "ENTERPRISE", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := extractTierIDFromAllowedTiers(tt.allowedTiers) + if got != tt.want { + t.Fatalf("extractTierIDFromAllowedTiers() = %q, want %q", got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:inferGoogleOneTier +// ===================== + +func TestInferGoogleOneTier(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + storageBytes int64 + want string + }{ + // 边界:<= 0 + {name: "0 bytes -> unknown", storageBytes: 0, want: GeminiTierGoogleOneUnknown}, + {name: "负数 -> unknown", storageBytes: -1, want: GeminiTierGoogleOneUnknown}, + + // > 100TB -> ultra + {name: "> 100TB -> ultra", storageBytes: int64(StorageTierUnlimited) + 1, want: GeminiTierGoogleAIUltra}, + {name: "200TB -> ultra", storageBytes: 200 * int64(TB), want: GeminiTierGoogleAIUltra}, + + // >= 2TB -> pro (但 <= 100TB) + {name: "正好 2TB -> pro", storageBytes: int64(StorageTierAIPremium), want: GeminiTierGoogleAIPro}, + {name: "5TB -> pro", storageBytes: 5 * int64(TB), want: GeminiTierGoogleAIPro}, + {name: "100TB 正好 -> pro (不是 > 100TB)", storageBytes: int64(StorageTierUnlimited), want: GeminiTierGoogleAIPro}, + + // >= 15GB -> free (但 < 2TB) + {name: "正好 15GB -> free", storageBytes: int64(StorageTierFree), want: GeminiTierGoogleOneFree}, + {name: "100GB -> free", storageBytes: 100 * int64(GB), want: GeminiTierGoogleOneFree}, + {name: "略低于 2TB -> free", storageBytes: int64(StorageTierAIPremium) - 1, want: GeminiTierGoogleOneFree}, + + // < 15GB -> unknown + {name: "1GB -> unknown", storageBytes: int64(GB), want: GeminiTierGoogleOneUnknown}, + {name: "略低于 15GB -> unknown", storageBytes: int64(StorageTierFree) - 1, want: GeminiTierGoogleOneUnknown}, + {name: "1 byte -> unknown", storageBytes: 1, want: GeminiTierGoogleOneUnknown}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := inferGoogleOneTier(tt.storageBytes) + if got != tt.want { + t.Fatalf("inferGoogleOneTier(%d) = %q, want %q", tt.storageBytes, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:isNonRetryableGeminiOAuthError +// ===================== + +func TestIsNonRetryableGeminiOAuthError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + {name: "invalid_grant", err: fmt.Errorf("error: invalid_grant"), want: true}, + {name: "invalid_client", err: fmt.Errorf("oauth error: invalid_client"), want: true}, + {name: "unauthorized_client", err: fmt.Errorf("unauthorized_client: mismatch"), want: true}, + {name: "access_denied", err: fmt.Errorf("access_denied by user"), want: true}, + {name: "普通网络错误", err: fmt.Errorf("connection timeout"), want: false}, + {name: "HTTP 500 错误", err: fmt.Errorf("server error 500"), want: false}, + {name: "空错误信息", err: fmt.Errorf(""), want: false}, + {name: "包含 invalid 但不是完整匹配", err: fmt.Errorf("invalid request"), want: false}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := isNonRetryableGeminiOAuthError(tt.err) + if got != tt.want { + t.Fatalf("isNonRetryableGeminiOAuthError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:BuildAccountCredentials +// ===================== + +func TestGeminiOAuthService_BuildAccountCredentials(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + t.Run("完整字段", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "access-123", + RefreshToken: "refresh-456", + ExpiresIn: 3600, + ExpiresAt: 1700000000, + TokenType: "Bearer", + Scope: "openid email", + ProjectID: "my-project", + TierID: "gcp_standard", + OAuthType: "code_assist", + Extra: map[string]any{ + "drive_storage_limit": int64(2199023255552), + }, + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + assertCredStr(t, creds, "access_token", "access-123") + assertCredStr(t, creds, "refresh_token", "refresh-456") + assertCredStr(t, creds, "token_type", "Bearer") + assertCredStr(t, creds, "scope", "openid email") + assertCredStr(t, creds, "project_id", "my-project") + assertCredStr(t, creds, "tier_id", "gcp_standard") + assertCredStr(t, creds, "oauth_type", "code_assist") + assertCredStr(t, creds, "expires_at", "1700000000") + + if _, ok := creds["drive_storage_limit"]; !ok { + t.Fatal("extra 字段 drive_storage_limit 未包含在 creds 中") + } + }) + + t.Run("最小字段(仅 access_token 和 expires_at)", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token-only", + ExpiresAt: 1700000000, + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + assertCredStr(t, creds, "access_token", "token-only") + assertCredStr(t, creds, "expires_at", "1700000000") + + // 可选字段不应存在 + for _, key := range []string{"refresh_token", "token_type", "scope", "project_id", "tier_id", "oauth_type"} { + if _, ok := creds[key]; ok { + t.Fatalf("不应包含空字段 %q", key) + } + } + }) + + t.Run("无效 tier_id 被静默跳过", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + TierID: "tier with spaces", + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + if _, ok := creds["tier_id"]; ok { + t.Fatal("无效 tier_id 不应被存入 creds") + } + }) + + t.Run("超长 tier_id 被静默跳过", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + TierID: strings.Repeat("x", 65), + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + if _, ok := creds["tier_id"]; ok { + t.Fatal("超长 tier_id 不应被存入 creds") + } + }) + + t.Run("无 extra 字段", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + RefreshToken: "rt", + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + // 仅包含基础字段 + if len(creds) != 3 { // access_token, expires_at, refresh_token + t.Fatalf("creds 字段数量不匹配: got=%d want=3, keys=%v", len(creds), credKeys(creds)) + } + }) +} + +// ===================== +// 新增测试:GetOAuthConfig +// ===================== + +func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.Config + wantEnabled bool + }{ + { + name: "无自定义 OAuth 客户端", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{}, + }, + }, + wantEnabled: false, + }, + { + name: "仅 ClientID 无 ClientSecret", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "custom-id", + }, + }, + }, + wantEnabled: false, + }, + { + name: "仅 ClientSecret 无 ClientID", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientSecret: "custom-secret", + }, + }, + }, + wantEnabled: false, + }, + { + name: "使用内置 Gemini CLI ClientID(不算自定义)", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: geminicli.GeminiCLIOAuthClientID, + ClientSecret: "some-secret", + }, + }, + }, + wantEnabled: false, + }, + { + name: "自定义 OAuth 客户端(非内置 ID)", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "my-custom-client-id", + ClientSecret: "my-custom-client-secret", + }, + }, + }, + wantEnabled: true, + }, + { + name: "带空白的自定义客户端", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: " my-custom-client-id ", + ClientSecret: " my-custom-client-secret ", + }, + }, + }, + wantEnabled: true, + }, + { + name: "纯空白字符串不算配置", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: " ", + ClientSecret: " ", + }, + }, + }, + wantEnabled: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg) + defer svc.Stop() + + result := svc.GetOAuthConfig() + if result.AIStudioOAuthEnabled != tt.wantEnabled { + t.Fatalf("AIStudioOAuthEnabled = %v, want %v", result.AIStudioOAuthEnabled, tt.wantEnabled) + } + // RequiredRedirectURIs 始终包含 AI Studio redirect URI + if len(result.RequiredRedirectURIs) != 1 || result.RequiredRedirectURIs[0] != geminicli.AIStudioOAuthRedirectURI { + t.Fatalf("RequiredRedirectURIs 不匹配: got=%v", result.RequiredRedirectURIs) + } + }) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.Stop +// ===================== + +func TestGeminiOAuthService_Stop_NoPanic(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + + // 调用 Stop 不应 panic + svc.Stop() + // 多次调用也不应 panic + svc.Stop() +} + +// ===================== +// mock: GeminiOAuthClient +// ===================== + +type mockGeminiOAuthClient struct { + exchangeCodeFunc func(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) + refreshTokenFunc func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) +} + +func (m *mockGeminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) { + if m.exchangeCodeFunc != nil { + return m.exchangeCodeFunc(ctx, oauthType, code, codeVerifier, redirectURI, proxyURL) + } + panic("ExchangeCode not implemented") +} + +func (m *mockGeminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if m.refreshTokenFunc != nil { + return m.refreshTokenFunc(ctx, oauthType, refreshToken, proxyURL) + } + panic("RefreshToken not implemented") +} + +// ===================== +// mock: GeminiCliCodeAssistClient +// ===================== + +type mockGeminiCodeAssistClient struct { + loadCodeAssistFunc func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) + onboardUserFunc func(ctx context.Context, accessToken, proxyURL string, req *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) +} + +func (m *mockGeminiCodeAssistClient) LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + if m.loadCodeAssistFunc != nil { + return m.loadCodeAssistFunc(ctx, accessToken, proxyURL, req) + } + panic("LoadCodeAssist not implemented") +} + +func (m *mockGeminiCodeAssistClient) OnboardUser(ctx context.Context, accessToken, proxyURL string, req *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) { + if m.onboardUserFunc != nil { + return m.onboardUserFunc(ctx, accessToken, proxyURL, req) + } + panic("OnboardUser not implemented") +} + +// ===================== +// mock: ProxyRepository (最小实现) +// ===================== + +type mockGeminiProxyRepo struct { + getByIDFunc func(ctx context.Context, id int64) (*Proxy, error) +} + +func (m *mockGeminiProxyRepo) Create(ctx context.Context, proxy *Proxy) error { panic("not impl") } +func (m *mockGeminiProxyRepo) GetByID(ctx context.Context, id int64) (*Proxy, error) { + if m.getByIDFunc != nil { + return m.getByIDFunc(ctx, id) + } + return nil, fmt.Errorf("proxy not found") +} +func (m *mockGeminiProxyRepo) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) Update(ctx context.Context, proxy *Proxy) error { panic("not impl") } +func (m *mockGeminiProxyRepo) Delete(ctx context.Context, id int64) error { panic("not impl") } +func (m *mockGeminiProxyRepo) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListActive(ctx context.Context) ([]Proxy, error) { panic("not impl") } +func (m *mockGeminiProxyRepo) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) { + panic("not impl") +} + +// mockDriveClient implements geminicli.DriveClient for tests. +type mockDriveClient struct { + getStorageQuotaFunc func(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error) +} + +func (m *mockDriveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error) { + if m.getStorageQuotaFunc != nil { + return m.getStorageQuotaFunc(ctx, accessToken, proxyURL) + } + return nil, fmt.Errorf("drive API not available in test") +} + +// ===================== +// 新增测试:GeminiOAuthService.RefreshToken(含重试逻辑) +// ===================== + +func TestGeminiOAuthService_RefreshToken_Success(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "new-access", + RefreshToken: "new-refresh", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "openid", + }, nil + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{}) + defer svc.Stop() + + info, err := svc.RefreshToken(context.Background(), "code_assist", "old-refresh", "") + if err != nil { + t.Fatalf("RefreshToken 返回错误: %v", err) + } + if info.AccessToken != "new-access" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if info.RefreshToken != "new-refresh" { + t.Fatalf("RefreshToken 不匹配: got=%q", info.RefreshToken) + } + if info.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } +} + +func TestGeminiOAuthService_RefreshToken_NonRetryableError(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return nil, fmt.Errorf("invalid_grant: token revoked") + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{}) + defer svc.Stop() + + _, err := svc.RefreshToken(context.Background(), "code_assist", "revoked-token", "") + if err == nil { + t.Fatal("RefreshToken 应返回错误(不可重试的 invalid_grant)") + } + if !strings.Contains(err.Error(), "invalid_grant") { + t.Fatalf("错误应包含 invalid_grant: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) { + t.Parallel() + + callCount := 0 + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + callCount++ + if callCount <= 2 { + return nil, fmt.Errorf("temporary network error") + } + return &geminicli.TokenResponse{ + AccessToken: "recovered", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{}) + defer svc.Stop() + + info, err := svc.RefreshToken(context.Background(), "code_assist", "rt", "") + if err != nil { + t.Fatalf("RefreshToken 应在重试后成功: %v", err) + } + if info.AccessToken != "recovered" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if callCount < 3 { + t.Fatalf("应至少调用 3 次(2 次失败 + 1 次成功): got=%d", callCount) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.RefreshAccountToken +// ===================== + +func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(非 Gemini OAuth 账号)") + } + if !strings.Contains(err.Error(), "not a Gemini OAuth account") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "at", + "oauth_type": "code_assist", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无 refresh_token)") + } + if !strings.Contains(err.Error(), "no refresh token") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_AIStudio(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "refreshed-at", + RefreshToken: "refreshed-rt", + ExpiresIn: 3600, + TokenType: "Bearer", + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-at", + "refresh_token": "old-rt", + "oauth_type": "ai_studio", + "tier_id": "aistudio_free", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.AccessToken != "refreshed-at" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if info.OAuthType != "ai_studio" { + t.Fatalf("OAuthType 不匹配: got=%q", info.OAuthType) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_WithProjectID(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + RefreshToken: "new-rt", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-at", + "refresh_token": "old-rt", + "oauth_type": "code_assist", + "project_id": "my-project", + "tier_id": "gcp_standard", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.ProjectID != "my-project" { + t.Fatalf("ProjectID 应保留: got=%q", info.ProjectID) + } + if info.TierID != GeminiTierGCPStandard { + t.Fatalf("TierID 不匹配: got=%q want=%q", info.TierID, GeminiTierGCPStandard) + } + if info.OAuthType != "code_assist" { + t.Fatalf("OAuthType 不匹配: got=%q", info.OAuthType) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_DefaultOAuthType(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if oauthType != "code_assist" { + t.Errorf("默认 oauthType 应为 code_assist: got=%q", oauthType) + } + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) + defer svc.Stop() + + // 无 oauth_type 凭据的旧账号 + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "old-rt", + "project_id": "proj", + "tier_id": "STANDARD", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.OAuthType != "code_assist" { + t.Fatalf("OAuthType 应默认为 code_assist: got=%q", info.OAuthType) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockGeminiProxyRepo{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + Protocol: "http", + Host: "proxy.test", + Port: 3128, + }, nil + }, + } + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if proxyURL != "http://proxy.test:3128" { + t.Errorf("proxyURL 不匹配: got=%q", proxyURL) + } + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(proxyRepo, client, nil, nil, &config.Config{}) + defer svc.Stop() + + proxyID := int64(5) + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + ProxyID: &proxyID, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_AutoDetect(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + codeAssist := &mockGeminiCodeAssistClient{ + loadCodeAssistFunc: func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + return &geminicli.LoadCodeAssistResponse{ + CloudAICompanionProject: "auto-project-123", + CurrentTier: &geminicli.TierInfo{ID: "STANDARD"}, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + // 无 project_id,触发 fetchProjectID + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.ProjectID != "auto-project-123" { + t.Fatalf("ProjectID 应为自动检测值: got=%q", info.ProjectID) + } + if info.TierID != GeminiTierGCPStandard { + t.Fatalf("TierID 不匹配: got=%q", info.TierID) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_FailsEmpty(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + // 返回有 currentTier 但无 cloudaicompanionProject 的响应, + // 使 fetchProjectID 走"已注册用户"路径(尝试 Cloud Resource Manager -> 失败 -> 返回错误), + // 避免走 onboardUser 路径(5 次重试 x 2 秒 = 10 秒超时) + codeAssist := &mockGeminiCodeAssistClient{ + loadCodeAssistFunc: func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + return &geminicli.LoadCodeAssistResponse{ + CurrentTier: &geminicli.TierInfo{ID: "STANDARD"}, + // 无 CloudAICompanionProject + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无法检测 project_id)") + } + if !strings.Contains(err.Error(), "project_id") { + t.Fatalf("错误信息应包含 project_id: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_FreshCache(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "google_one", + "project_id": "proj", + "tier_id": "google_ai_pro", + }, + Extra: map[string]any{ + // 缓存刷新时间在 24 小时内 + "drive_tier_updated_at": time.Now().Add(-1 * time.Hour).Format(time.RFC3339), + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + // 缓存新鲜,应使用已有的 tier_id + if info.TierID != GeminiTierGoogleAIPro { + t.Fatalf("TierID 应使用缓存值: got=%q want=%q", info.TierID, GeminiTierGoogleAIPro) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_NoTierID_DefaultsFree(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &mockDriveClient{}, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "google_one", + "project_id": "proj", + // 无 tier_id + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + // FetchGoogleOneTier 会被调用但 oauthClient(此处 mock)不实现 Drive API, + // svc.FetchGoogleOneTier 使用真实 DriveClient 会失败,最终回退到默认值。 + // 由于没有 tier_id 且 FetchGoogleOneTier 失败,应默认为 google_one_free + if info.TierID != GeminiTierGoogleOneFree { + t.Fatalf("TierID 应为默认 free: got=%q", info.TierID) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_Fallback(t *testing.T) { + t.Parallel() + + callCount := 0 + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + callCount++ + if oauthType == "code_assist" { + return nil, fmt.Errorf("unauthorized_client: client mismatch") + } + // ai_studio 路径成功 + return &geminicli.TokenResponse{ + AccessToken: "recovered", + ExpiresIn: 3600, + }, nil + }, + } + + // 启用自定义 OAuth 客户端以触发 fallback 路径 + cfg := &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + }, + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, cfg) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + "tier_id": "gcp_standard", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 应在 fallback 后成功: %v", err) + } + if info.AccessToken != "recovered" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return nil, fmt.Errorf("unauthorized_client: client mismatch") + }, + } + + // 无自定义 OAuth 客户端,无法 fallback + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无 fallback)") + } + if !strings.Contains(err.Error(), "OAuth client mismatch") { + t.Fatalf("错误应包含 OAuth client mismatch: got=%q", err.Error()) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.ExchangeCode +// ===================== + +func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ + SessionID: "nonexistent", + State: "some-state", + Code: "some-code", + }) + if err == nil { + t.Fatal("应返回错误(session 不存在)") + } + if !strings.Contains(err.Error(), "session not found") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + // 手动创建 session(必须设置 CreatedAt,否则会因 TTL 过期被拒绝) + svc.sessionStore.Set("test-session", &geminicli.OAuthSession{ + State: "correct-state", + CodeVerifier: "verifier", + OAuthType: "ai_studio", + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ + SessionID: "test-session", + State: "wrong-state", + Code: "code", + }) + if err == nil { + t.Fatal("应返回错误(state 不匹配)") + } + if !strings.Contains(err.Error(), "invalid state") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_ExchangeCode_EmptyState(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + svc.sessionStore.Set("test-session", &geminicli.OAuthSession{ + State: "correct-state", + CodeVerifier: "verifier", + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ + SessionID: "test-session", + State: "", + Code: "code", + }) + if err == nil { + t.Fatal("应返回错误(空 state)") + } +} + +// ===================== +// 辅助函数 +// ===================== + +func assertCredStr(t *testing.T, creds map[string]any, key, want string) { + t.Helper() + raw, ok := creds[key] + if !ok { + t.Fatalf("creds 缺少 key=%q", key) + } + got, ok := raw.(string) + if !ok { + t.Fatalf("creds[%q] 不是 string: %T", key, raw) + } + if got != want { + t.Fatalf("creds[%q] = %q, want %q", key, got, want) + } +} + +func credKeys(m map[string]any) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} diff --git a/backend/internal/service/gemini_session.go b/backend/internal/service/gemini_session.go index 859ae9f3..1780d1da 100644 --- a/backend/internal/service/gemini_session.go +++ b/backend/internal/service/gemini_session.go @@ -6,26 +6,11 @@ import ( "encoding/json" "strconv" "strings" - "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/cespare/xxhash/v2" ) -// Gemini 会话 ID Fallback 相关常量 -const ( - // geminiSessionTTLSeconds Gemini 会话缓存 TTL(5 分钟) - geminiSessionTTLSeconds = 300 - - // geminiSessionKeyPrefix Gemini 会话 Redis key 前缀 - geminiSessionKeyPrefix = "gemini:sess:" -) - -// GeminiSessionTTL 返回 Gemini 会话缓存 TTL -func GeminiSessionTTL() time.Duration { - return geminiSessionTTLSeconds * time.Second -} - // shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符) // XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20% func shortHash(data []byte) string { @@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m return base64.RawURLEncoding.EncodeToString(hash[:12]) } -// BuildGeminiSessionKey 构建 Gemini 会话 Redis key -// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain} -func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string { - return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain -} - -// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短) -// 用于 MGET 批量查询最长匹配 -func GenerateDigestChainPrefixes(chain string) []string { - if chain == "" { - return nil - } - - var prefixes []string - c := chain - - for c != "" { - prefixes = append(prefixes, c) - // 找到最后一个 "-" 的位置 - if i := strings.LastIndex(c, "-"); i > 0 { - c = c[:i] - } else { - break - } - } - - return prefixes -} - // ParseGeminiSessionValue 解析 Gemini 会话缓存值 // 格式: {uuid}:{accountID} func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) { @@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string { // geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀 const geminiDigestSessionKeyPrefix = "gemini:digest:" -// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀 -const geminiTrieKeyPrefix = "gemini:trie:" - -// BuildGeminiTrieKey 构建 Gemini Trie Redis key -// 格式: gemini:trie:{groupID}:{prefixHash} -func BuildGeminiTrieKey(groupID int64, prefixHash string) string { - return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash -} - // GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey // 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey // 用于在 SelectAccountWithLoadAwareness 中保持粘性会话 diff --git a/backend/internal/service/gemini_session_integration_test.go b/backend/internal/service/gemini_session_integration_test.go index 928c62cf..95b5f594 100644 --- a/backend/internal/service/gemini_session_integration_test.go +++ b/backend/internal/service/gemini_session_integration_test.go @@ -1,41 +1,14 @@ package service import ( - "context" "testing" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) -// mockGeminiSessionCache 模拟 Redis 缓存 -type mockGeminiSessionCache struct { - sessions map[string]string // key -> value -} - -func newMockGeminiSessionCache() *mockGeminiSessionCache { - return &mockGeminiSessionCache{sessions: make(map[string]string)} -} - -func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) { - key := BuildGeminiSessionKey(groupID, prefixHash, digestChain) - value := FormatGeminiSessionValue(uuid, accountID) - m.sessions[key] = value -} - -func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - prefixes := GenerateDigestChainPrefixes(digestChain) - for _, p := range prefixes { - key := BuildGeminiSessionKey(groupID, prefixHash, p) - if val, ok := m.sessions[key]; ok { - return ParseGeminiSessionValue(val) - } - } - return "", 0, false -} - // TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配 func TestGeminiSessionContinuousConversation(t *testing.T) { - cache := newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" sessionUUID := "session-uuid-12345" @@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 1 chain: %s", chain1) // 第一轮:没有找到会话,创建新会话 - _, _, found := cache.Find(groupID, prefixHash, chain1) + _, _, _, found := store.Find(groupID, prefixHash, chain1) if found { t.Error("Round 1: should not find existing session") } - // 保存第一轮会话 - cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID) + // 保存第一轮会话(首轮无旧 chain) + store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "") // 模拟第二轮对话(用户继续对话) req2 := &antigravity.GeminiRequest{ @@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 2 chain: %s", chain2) // 第二轮:应该能找到会话(通过前缀匹配) - foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2) + foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2) if !found { t.Error("Round 2: should find session via prefix matching") } @@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID) } - // 保存第二轮会话 - cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID) + // 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key + store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain) // 模拟第三轮对话 req3 := &antigravity.GeminiRequest{ @@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 3 chain: %s", chain3) // 第三轮:应该能找到会话(通过第二轮的前缀匹配) - foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3) + foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3) if !found { t.Error("Round 3: should find session via prefix matching") } @@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { if foundAccID != accountID { t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID) } - - t.Log("✓ Continuous conversation session matching works correctly!") } // TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配 func TestGeminiSessionDifferentConversations(t *testing.T) { - cache := newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" @@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) { }, } chain1 := BuildGeminiDigestChain(req1) - cache.Save(groupID, prefixHash, chain1, "session-1", 100) + store.Save(groupID, prefixHash, chain1, "session-1", 100, "") // 第二个完全不同的会话 req2 := &antigravity.GeminiRequest{ @@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) { chain2 := BuildGeminiDigestChain(req2) // 不同会话不应该匹配 - _, _, found := cache.Find(groupID, prefixHash, chain2) + _, _, _, found := store.Find(groupID, prefixHash, chain2) if found { t.Error("Different conversations should not match") } - - t.Log("✓ Different conversations are correctly isolated!") } // TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先) func TestGeminiSessionPrefixMatchingOrder(t *testing.T) { - cache := newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" - // 创建一个三轮对话 - req := &antigravity.GeminiRequest{ - SystemInstruction: &antigravity.GeminiContent{ - Parts: []antigravity.GeminiPart{{Text: "System prompt"}}, - }, - Contents: []antigravity.GeminiContent{ - {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}}, - {Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}}, - {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}}, - }, - } - fullChain := BuildGeminiDigestChain(req) - prefixes := GenerateDigestChainPrefixes(fullChain) - - t.Logf("Full chain: %s", fullChain) - t.Logf("Prefixes (longest first): %v", prefixes) - - // 验证前缀生成顺序(从长到短) - if len(prefixes) != 4 { - t.Errorf("Expected 4 prefixes, got %d", len(prefixes)) - } - // 保存不同轮次的会话到不同账号 - // 第一轮(最短前缀)-> 账号 1 - cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1) - // 第二轮 -> 账号 2 - cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2) - // 第三轮(最长前缀,完整链)-> 账号 3 - cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3) + store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "") + store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "") + store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "") - // 查找应该返回最长匹配(账号 3) - _, accID, found := cache.Find(groupID, prefixHash, fullChain) + // 查找更长的链,应该返回最长匹配(账号 3) + _, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2") if !found { t.Error("Should find session") } if accID != 3 { t.Errorf("Should match longest prefix (account 3), got account %d", accID) } - - t.Log("✓ Longest prefix matching works correctly!") } - -// 确保 context 包被使用(避免未使用的导入警告) -var _ = context.Background diff --git a/backend/internal/service/gemini_session_test.go b/backend/internal/service/gemini_session_test.go index 8c1908f7..a034cddd 100644 --- a/backend/internal/service/gemini_session_test.go +++ b/backend/internal/service/gemini_session_test.go @@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) { } } -func TestGenerateDigestChainPrefixes(t *testing.T) { - tests := []struct { - name string - chain string - want []string - wantLen int - }{ - { - name: "empty", - chain: "", - wantLen: 0, - }, - { - name: "single part", - chain: "u:abc123", - want: []string{"u:abc123"}, - wantLen: 1, - }, - { - name: "two parts", - chain: "s:xyz-u:abc", - want: []string{"s:xyz-u:abc", "s:xyz"}, - wantLen: 2, - }, - { - name: "four parts", - chain: "s:a-u:b-m:c-u:d", - want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"}, - wantLen: 4, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := GenerateDigestChainPrefixes(tt.chain) - - if len(result) != tt.wantLen { - t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result) - } - - if tt.want != nil { - for i, want := range tt.want { - if i >= len(result) { - t.Errorf("missing prefix at index %d", i) - continue - } - if result[i] != want { - t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i]) - } - } - } - }) - } -} - func TestParseGeminiSessionValue(t *testing.T) { tests := []struct { name string @@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) { } }) } - -func TestBuildGeminiTrieKey(t *testing.T) { - tests := []struct { - name string - groupID int64 - prefixHash string - want string - }{ - { - name: "normal", - groupID: 123, - prefixHash: "abcdef12", - want: "gemini:trie:123:abcdef12", - }, - { - name: "zero group", - groupID: 0, - prefixHash: "xyz", - want: "gemini:trie:0:xyz", - }, - { - name: "empty prefix", - groupID: 1, - prefixHash: "", - want: "gemini:trie:1:", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash) - if got != tt.want { - t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want) - } - }) - } -} diff --git a/backend/internal/service/generate_session_hash_test.go b/backend/internal/service/generate_session_hash_test.go new file mode 100644 index 00000000..8aa358a5 --- /dev/null +++ b/backend/internal/service/generate_session_hash_test.go @@ -0,0 +1,1213 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// ============ 基础优先级测试 ============ + +func TestGenerateSessionHash_NilParsedRequest(t *testing.T) { + svc := &GatewayService{} + require.Empty(t, svc.GenerateSessionHash(nil)) +} + +func TestGenerateSessionHash_EmptyRequest(t *testing.T) { + svc := &GatewayService{} + require.Empty(t, svc.GenerateSessionHash(&ParsedRequest{})) +} + +func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + hash := svc.GenerateSessionHash(parsed) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", hash, "metadata session_id should have highest priority") +} + +// ============ System + Messages 基础测试 ============ + +func TestGenerateSessionHash_SystemPlusMessages(t *testing.T) { + svc := &GatewayService{} + + withSystem := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + withoutSystem := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + h1 := svc.GenerateSessionHash(withSystem) + h2 := svc.GenerateSessionHash(withoutSystem) + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEqual(t, h1, h2, "system prompt should be part of digest, producing different hash") +} + +func TestGenerateSessionHash_SystemOnlyProducesHash(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + } + hash := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, hash, "system prompt alone should produce a hash as part of full digest") +} + +func TestGenerateSessionHash_DifferentSystemsSameMessages(t *testing.T) { + svc := &GatewayService{} + + parsed1 := &ParsedRequest{ + System: "You are assistant A.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + parsed2 := &ParsedRequest{ + System: "You are assistant B.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "different system prompts with same messages should produce different hashes") +} + +func TestGenerateSessionHash_SameSystemSameMessages(t *testing.T) { + svc := &GatewayService{} + + mk := func() *ParsedRequest { + return &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi"}, + }, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same system + same messages should produce identical hash") +} + +func TestGenerateSessionHash_DifferentMessagesProduceDifferentHash(t *testing.T) { + svc := &GatewayService{} + + parsed1 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "help me with Go"}, + }, + } + parsed2 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "help me with Python"}, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "same system but different messages should produce different hashes") +} + +// ============ SessionContext 核心测试 ============ + +func TestGenerateSessionHash_DifferentSessionContextProducesDifferentHash(t *testing.T) { + svc := &GatewayService{} + + // 相同消息 + 不同 SessionContext → 不同 hash(解决碰撞问题的核心场景) + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 100, + }, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "curl/7.0", + APIKeyID: 200, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEqual(t, h1, h2, "same messages but different SessionContext should produce different hashes") +} + +func TestGenerateSessionHash_SameSessionContextProducesSameHash(t *testing.T) { + svc := &GatewayService{} + + mk := func() *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 100, + }, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same messages + same SessionContext should produce identical hash") +} + +func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 100, + }, + } + + hash := svc.GenerateSessionHash(parsed) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", hash, + "metadata session_id should take priority over SessionContext") +} + +func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) { + svc := &GatewayService{} + + withCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: nil, + } + withoutCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + h1 := svc.GenerateSessionHash(withCtx) + h2 := svc.GenerateSessionHash(withoutCtx) + require.Equal(t, h1, h2, "nil SessionContext should produce same hash as no SessionContext") +} + +// ============ 多轮连续会话测试 ============ + +func TestGenerateSessionHash_ContinuousConversation_HashChangesWithMessages(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 模拟连续会话:每增加一轮对话,hash 应该不同(内容累积变化) + round1 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: ctx, + } + + round2 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "Hi there!"}, + map[string]any{"role": "user", "content": "How are you?"}, + }, + SessionContext: ctx, + } + + round3 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "Hi there!"}, + map[string]any{"role": "user", "content": "How are you?"}, + map[string]any{"role": "assistant", "content": "I'm doing well!"}, + map[string]any{"role": "user", "content": "Tell me a joke"}, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(round1) + h2 := svc.GenerateSessionHash(round2) + h3 := svc.GenerateSessionHash(round3) + + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEmpty(t, h3) + require.NotEqual(t, h1, h2, "different conversation rounds should produce different hashes") + require.NotEqual(t, h2, h3, "each new round should produce a different hash") + require.NotEqual(t, h1, h3, "round 1 and round 3 should differ") +} + +func TestGenerateSessionHash_ContinuousConversation_SameRoundSameHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 同一轮对话重复请求(如重试)应产生相同 hash + mk := func() *ParsedRequest { + return &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "Hi there!"}, + map[string]any{"role": "user", "content": "How are you?"}, + }, + SessionContext: ctx, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same conversation state should produce identical hash on retry") +} + +// ============ 消息回退测试 ============ + +func TestGenerateSessionHash_MessageRollback(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 模拟消息回退:用户删掉最后一轮再重发 + original := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + map[string]any{"role": "assistant", "content": "reply2"}, + map[string]any{"role": "user", "content": "msg3"}, + }, + SessionContext: ctx, + } + + // 回退到 msg2 后,用新的 msg3 替代 + rollback := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + map[string]any{"role": "assistant", "content": "reply2"}, + map[string]any{"role": "user", "content": "different msg3"}, + }, + SessionContext: ctx, + } + + hOrig := svc.GenerateSessionHash(original) + hRollback := svc.GenerateSessionHash(rollback) + require.NotEqual(t, hOrig, hRollback, "rollback with different last message should produce different hash") +} + +func TestGenerateSessionHash_MessageRollbackSameContent(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 回退后重新发送相同内容 → 相同 hash(合理的粘性恢复) + mk := func() *ParsedRequest { + return &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + }, + SessionContext: ctx, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "rollback and resend same content should produce same hash") +} + +// ============ 相同 System、不同用户消息 ============ + +func TestGenerateSessionHash_SameSystemDifferentUsers(t *testing.T) { + svc := &GatewayService{} + + // 两个不同用户使用相同 system prompt 但发送不同消息 + user1 := &ParsedRequest{ + System: "You are a code reviewer.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "Review this Go code"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "vscode", + APIKeyID: 1, + }, + } + user2 := &ParsedRequest{ + System: "You are a code reviewer.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "Review this Python code"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "vscode", + APIKeyID: 2, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "different users with different messages should get different hashes") +} + +func TestGenerateSessionHash_SameSystemSameMessageDifferentContext(t *testing.T) { + svc := &GatewayService{} + + // 这是修复的核心场景:两个不同用户发送完全相同的 system + messages(如 "hello") + // 有了 SessionContext 后应该产生不同 hash + user1 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 10, + }, + } + user2 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "Mozilla/5.0", + APIKeyID: 20, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "CRITICAL: same system+messages but different users should get different hashes") +} + +// ============ SessionContext 各字段独立影响测试 ============ + +func TestGenerateSessionHash_SessionContext_IPDifference(t *testing.T) { + svc := &GatewayService{} + + base := func(ip string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: ip, + UserAgent: "same-ua", + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("1.1.1.1")) + h2 := svc.GenerateSessionHash(base("2.2.2.2")) + require.NotEqual(t, h1, h2, "different IP should produce different hash") +} + +func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) { + svc := &GatewayService{} + + base := func(ua string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: ua, + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("Mozilla/5.0")) + h2 := svc.GenerateSessionHash(base("curl/7.0")) + require.NotEqual(t, h1, h2, "different User-Agent should produce different hash") +} + +func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) { + svc := &GatewayService{} + + base := func(keyID int64) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "same-ua", + APIKeyID: keyID, + }, + } + } + + h1 := svc.GenerateSessionHash(base(1)) + h2 := svc.GenerateSessionHash(base(2)) + require.NotEqual(t, h1, h2, "different APIKeyID should produce different hash") +} + +// ============ 多用户并发相同消息场景 ============ + +func TestGenerateSessionHash_MultipleUsersSameFirstMessage(t *testing.T) { + svc := &GatewayService{} + + // 模拟 5 个不同用户同时发送 "hello" → 应该产生 5 个不同的 hash + hashes := make(map[string]bool) + for i := 0; i < 5; i++ { + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1." + string(rune('1'+i)), + UserAgent: "client-" + string(rune('A'+i)), + APIKeyID: int64(i + 1), + }, + } + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + require.False(t, hashes[h], "hash collision detected for user %d", i) + hashes[h] = true + } + require.Len(t, hashes, 5, "5 different users should produce 5 unique hashes") +} + +// ============ 连续会话粘性:多轮对话同一用户 ============ + +func TestGenerateSessionHash_SameUserGrowingConversation(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "browser", APIKeyID: 42} + + // 模拟同一用户的连续会话,每轮 hash 不同但同用户重试保持一致 + messages := []map[string]any{ + {"role": "user", "content": "msg1"}, + {"role": "assistant", "content": "reply1"}, + {"role": "user", "content": "msg2"}, + {"role": "assistant", "content": "reply2"}, + {"role": "user", "content": "msg3"}, + {"role": "assistant", "content": "reply3"}, + {"role": "user", "content": "msg4"}, + } + + prevHash := "" + for round := 1; round <= len(messages); round += 2 { + // 构建前 round 条消息 + msgs := make([]any, round) + for j := 0; j < round; j++ { + msgs[j] = messages[j] + } + parsed := &ParsedRequest{ + System: "System", + HasSystem: true, + Messages: msgs, + SessionContext: ctx, + } + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "round %d hash should not be empty", round) + + if prevHash != "" { + require.NotEqual(t, prevHash, h, "round %d hash should differ from previous round", round) + } + prevHash = h + + // 同一轮重试应该相同 + h2 := svc.GenerateSessionHash(parsed) + require.Equal(t, h, h2, "retry of round %d should produce same hash", round) + } +} + +// ============ 多轮消息内容结构化测试 ============ + +func TestGenerateSessionHash_MultipleUserMessages(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 5 条用户消息(无 assistant 回复) + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "first"}, + map[string]any{"role": "user", "content": "second"}, + map[string]any{"role": "user", "content": "third"}, + map[string]any{"role": "user", "content": "fourth"}, + map[string]any{"role": "user", "content": "fifth"}, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + + // 修改中间一条消息应该改变 hash + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "first"}, + map[string]any{"role": "user", "content": "CHANGED"}, + map[string]any{"role": "user", "content": "third"}, + map[string]any{"role": "user", "content": "fourth"}, + map[string]any{"role": "user", "content": "fifth"}, + }, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "changing any message should change the hash") +} + +func TestGenerateSessionHash_MessageOrderMatters(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "alpha"}, + map[string]any{"role": "user", "content": "beta"}, + }, + SessionContext: ctx, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "beta"}, + map[string]any{"role": "user", "content": "alpha"}, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "message order should affect the hash") +} + +// ============ 复杂内容格式测试 ============ + +func TestGenerateSessionHash_StructuredContent(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 结构化 content(数组形式) + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "Look at this"}, + map[string]any{"type": "text", "text": "And this too"}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "structured content should produce a hash") +} + +func TestGenerateSessionHash_ArraySystemPrompt(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 数组格式的 system prompt + parsed := &ParsedRequest{ + System: []any{ + map[string]any{"type": "text", "text": "You are a helpful assistant."}, + map[string]any{"type": "text", "text": "Be concise."}, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "array system prompt should produce a hash") +} + +// ============ SessionContext 与 cache_control 优先级 ============ + +func TestGenerateSessionHash_CacheControlOverridesSessionContext(t *testing.T) { + svc := &GatewayService{} + + // 当有 cache_control: ephemeral 时,使用第 2 级优先级 + // SessionContext 不应影响结果 + parsed1 := &ParsedRequest{ + System: []any{ + map[string]any{ + "type": "text", + "text": "You are a tool-specific assistant.", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "ua1", + APIKeyID: 100, + }, + } + parsed2 := &ParsedRequest{ + System: []any{ + map[string]any{ + "type": "text", + "text": "You are a tool-specific assistant.", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "ua2", + APIKeyID: 200, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.Equal(t, h1, h2, "cache_control ephemeral has higher priority, SessionContext should not affect result") +} + +// ============ 边界情况 ============ + +func TestGenerateSessionHash_EmptyMessages(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Messages: []any{}, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "test", + APIKeyID: 1, + }, + } + + // 空 messages + 只有 SessionContext 时,combined.Len() > 0 因为有 context 写入 + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "empty messages with SessionContext should still produce a hash from context") +} + +func TestGenerateSessionHash_EmptyMessagesNoContext(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Messages: []any{}, + } + + h := svc.GenerateSessionHash(parsed) + require.Empty(t, h, "empty messages without SessionContext should produce empty hash") +} + +func TestGenerateSessionHash_SessionContextWithEmptyFields(t *testing.T) { + svc := &GatewayService{} + + // SessionContext 字段为空字符串和零值时仍应影响 hash + withEmptyCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "", + UserAgent: "", + APIKeyID: 0, + }, + } + withoutCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + } + + h1 := svc.GenerateSessionHash(withEmptyCtx) + h2 := svc.GenerateSessionHash(withoutCtx) + // 有 SessionContext(即使字段为空)仍然会写入分隔符 "::" 等 + require.NotEqual(t, h1, h2, "empty-field SessionContext should still differ from nil SessionContext") +} + +// ============ 长对话历史测试 ============ + +func TestGenerateSessionHash_LongConversation(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 构建 20 轮对话 + messages := make([]any, 0, 40) + for i := 0; i < 20; i++ { + messages = append(messages, map[string]any{ + "role": "user", + "content": "user message " + string(rune('A'+i)), + }) + messages = append(messages, map[string]any{ + "role": "assistant", + "content": "assistant reply " + string(rune('A'+i)), + }) + } + + parsed := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: messages, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + + // 再加一轮应该不同 + moreMessages := make([]any, len(messages)+2) + copy(moreMessages, messages) + moreMessages[len(messages)] = map[string]any{"role": "user", "content": "one more"} + moreMessages[len(messages)+1] = map[string]any{"role": "assistant", "content": "ok"} + + parsed2 := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: moreMessages, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "adding more messages to long conversation should change hash") +} + +// ============ Gemini 原生格式 session hash 测试 ============ + +func TestGenerateSessionHash_GeminiContentsProducesHash(t *testing.T) { + svc := &GatewayService{} + + // Gemini 格式: contents[].parts[].text + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello from Gemini"}, + }, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "1.2.3.4", + UserAgent: "gemini-cli", + APIKeyID: 1, + }, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "Gemini contents with parts should produce a non-empty hash") +} + +func TestGenerateSessionHash_GeminiDifferentContentsDifferentHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello"}, + }, + }, + }, + SessionContext: ctx, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Goodbye"}, + }, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "different Gemini contents should produce different hashes") +} + +func TestGenerateSessionHash_GeminiSameContentsSameHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + mk := func() *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello"}, + }, + }, + map[string]any{ + "role": "model", + "parts": []any{ + map[string]any{"text": "Hi there!"}, + }, + }, + }, + SessionContext: ctx, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same Gemini contents should produce identical hash") +} + +func TestGenerateSessionHash_GeminiMultiTurnHashChanges(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + round1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + + round2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + map[string]any{ + "role": "model", + "parts": []any{map[string]any{"text": "Hi!"}}, + }, + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "How are you?"}}, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(round1) + h2 := svc.GenerateSessionHash(round2) + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEqual(t, h1, h2, "Gemini multi-turn should produce different hashes per round") +} + +func TestGenerateSessionHash_GeminiDifferentUsersSameContentDifferentHash(t *testing.T) { + svc := &GatewayService{} + + // 核心场景:两个不同用户发送相同 Gemini 格式消息应得到不同 hash + user1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "gemini-cli", + APIKeyID: 10, + }, + } + user2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "gemini-cli", + APIKeyID: 20, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "CRITICAL: different Gemini users with same content must get different hashes") +} + +func TestGenerateSessionHash_GeminiSystemInstructionAffectsHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // systemInstruction 经 ParseGatewayRequest 解析后存入 parsed.System + withSys := &ParsedRequest{ + System: []any{ + map[string]any{"text": "You are a coding assistant."}, + }, + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + withoutSys := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(withSys) + h2 := svc.GenerateSessionHash(withoutSys) + require.NotEqual(t, h1, h2, "systemInstruction should affect the hash") +} + +func TestGenerateSessionHash_GeminiMultiPartMessage(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // 多 parts 的消息 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Part 1"}, + map[string]any{"text": "Part 2"}, + map[string]any{"text": "Part 3"}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "multi-part Gemini message should produce a hash") + + // 不同内容的多 parts + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Part 1"}, + map[string]any{"text": "CHANGED"}, + map[string]any{"text": "Part 3"}, + }, + }, + }, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "changing a part should change the hash") +} + +func TestGenerateSessionHash_GeminiNonTextPartsIgnored(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // 含非 text 类型 parts(如 inline_data),应被跳过但不报错 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Describe this image"}, + map[string]any{"inline_data": map[string]any{"mime_type": "image/png", "data": "base64..."}}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "Gemini message with mixed parts should still produce a hash from text parts") +} + +func TestGenerateSessionHash_GeminiMultiTurnHashNotSticky(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "10.0.0.1", UserAgent: "gemini-cli", APIKeyID: 42} + + // 模拟同一 Gemini 会话的三轮请求,每轮 contents 累积增长。 + // 验证预期行为:每轮 hash 都不同,即 GenerateSessionHash 不具备跨轮粘性。 + // 这是 by-design 的——Gemini 的跨轮粘性由 Digest Fallback(BuildGeminiDigestChain)负责。 + round1Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]} + ] + }`) + round2Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "func hello() {}"}]}, + {"role": "user", "parts": [{"text": "Add error handling"}]} + ] + }`) + round3Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "func hello() {}"}]}, + {"role": "user", "parts": [{"text": "Add error handling"}]}, + {"role": "model", "parts": [{"text": "func hello() error { return nil }"}]}, + {"role": "user", "parts": [{"text": "Now add tests"}]} + ] + }`) + + hashes := make([]string, 3) + for i, body := range [][]byte{round1Body, round2Body, round3Body} { + parsed, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed.SessionContext = ctx + hashes[i] = svc.GenerateSessionHash(parsed) + require.NotEmpty(t, hashes[i], "round %d hash should not be empty", i+1) + } + + // 每轮 hash 都不同——这是预期行为 + require.NotEqual(t, hashes[0], hashes[1], "round 1 vs 2 hash should differ (contents grow)") + require.NotEqual(t, hashes[1], hashes[2], "round 2 vs 3 hash should differ (contents grow)") + require.NotEqual(t, hashes[0], hashes[2], "round 1 vs 3 hash should differ") + + // 同一轮重试应产生相同 hash + parsed1Again, err := ParseGatewayRequest(round2Body, "gemini") + require.NoError(t, err) + parsed1Again.SessionContext = ctx + h2Again := svc.GenerateSessionHash(parsed1Again) + require.Equal(t, hashes[1], h2Again, "retry of same round should produce same hash") +} + +func TestGenerateSessionHash_GeminiEndToEnd(t *testing.T) { + svc := &GatewayService{} + + // 端到端测试:模拟 ParseGatewayRequest + GenerateSessionHash 完整流程 + body := []byte(`{ + "model": "gemini-2.5-pro", + "systemInstruction": { + "parts": [{"text": "You are a coding assistant."}] + }, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "Here is a function..."}]}, + {"role": "user", "parts": [{"text": "Now add error handling"}]} + ] + }`) + + parsed, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed.SessionContext = &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "gemini-cli/1.0", + APIKeyID: 42, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "end-to-end Gemini flow should produce a hash") + + // 同一请求再次解析应产生相同 hash + parsed2, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed2.SessionContext = &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "gemini-cli/1.0", + APIKeyID: 42, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.Equal(t, h, h2, "same request should produce same hash") + + // 不同用户发送相同请求应产生不同 hash + parsed3, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed3.SessionContext = &SessionContext{ + ClientIP: "10.0.0.2", + UserAgent: "gemini-cli/1.0", + APIKeyID: 99, + } + + h3 := svc.GenerateSessionHash(parsed3) + require.NotEqual(t, h, h3, "different user with same Gemini request should get different hash") +} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 1302047a..6990caca 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -26,6 +26,15 @@ type Group struct { ImagePrice2K *float64 ImagePrice4K *float64 + // Sora 按次计费配置(阶段 1) + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + + // Sora 存储配额 + SoraStorageQuotaBytes int64 + // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 @@ -45,6 +54,9 @@ type Group struct { // 可选值: claude, gemini_text, gemini_image SupportedModelScopes []string + // 分组排序 + SortOrder int + CreatedAt time.Time UpdatedAt time.Time @@ -92,6 +104,18 @@ func (g *Group) GetImagePrice(imageSize string) *float64 { } } +// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540) +func (g *Group) GetSoraImagePrice(imageSize string) *float64 { + switch imageSize { + case "360": + return g.SoraImagePrice360 + case "540": + return g.SoraImagePrice540 + default: + return g.SoraImagePrice360 + } +} + // IsGroupContextValid reports whether a group from context has the fields required for routing decisions. func IsGroupContextValid(group *Group) bool { if group == nil { diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index a2bf2073..22a67eda 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -33,6 +33,14 @@ type GroupRepository interface { GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) // BindAccountsToGroup 将多个账号绑定到指定分组 BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error + // UpdateSortOrders 批量更新分组排序 + UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error +} + +// GroupSortOrderUpdate 分组排序更新 +type GroupSortOrderUpdate struct { + ID int64 `json:"id"` + SortOrder int `json:"sort_order"` } // CreateGroupRequest 创建分组请求 diff --git a/backend/internal/service/idempotency.go b/backend/internal/service/idempotency.go new file mode 100644 index 00000000..2a86bd60 --- /dev/null +++ b/backend/internal/service/idempotency.go @@ -0,0 +1,471 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" +) + +const ( + IdempotencyStatusProcessing = "processing" + IdempotencyStatusSucceeded = "succeeded" + IdempotencyStatusFailedRetryable = "failed_retryable" +) + +var ( + ErrIdempotencyKeyRequired = infraerrors.BadRequest("IDEMPOTENCY_KEY_REQUIRED", "idempotency key is required") + ErrIdempotencyKeyInvalid = infraerrors.BadRequest("IDEMPOTENCY_KEY_INVALID", "idempotency key is invalid") + ErrIdempotencyKeyConflict = infraerrors.Conflict("IDEMPOTENCY_KEY_CONFLICT", "idempotency key reused with different payload") + ErrIdempotencyInProgress = infraerrors.Conflict("IDEMPOTENCY_IN_PROGRESS", "idempotent request is still processing") + ErrIdempotencyRetryBackoff = infraerrors.Conflict("IDEMPOTENCY_RETRY_BACKOFF", "idempotent request is in retry backoff window") + ErrIdempotencyStoreUnavail = infraerrors.ServiceUnavailable("IDEMPOTENCY_STORE_UNAVAILABLE", "idempotency store unavailable") + ErrIdempotencyInvalidPayload = infraerrors.BadRequest("IDEMPOTENCY_PAYLOAD_INVALID", "failed to normalize request payload") +) + +type IdempotencyRecord struct { + ID int64 + Scope string + IdempotencyKeyHash string + RequestFingerprint string + Status string + ResponseStatus *int + ResponseBody *string + ErrorReason *string + LockedUntil *time.Time + ExpiresAt time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +type IdempotencyRepository interface { + CreateProcessing(ctx context.Context, record *IdempotencyRecord) (bool, error) + GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*IdempotencyRecord, error) + TryReclaim(ctx context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) + ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) + MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error + MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error + DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) +} + +type IdempotencyConfig struct { + DefaultTTL time.Duration + SystemOperationTTL time.Duration + ProcessingTimeout time.Duration + FailedRetryBackoff time.Duration + MaxStoredResponseLen int + ObserveOnly bool +} + +func DefaultIdempotencyConfig() IdempotencyConfig { + return IdempotencyConfig{ + DefaultTTL: 24 * time.Hour, + SystemOperationTTL: 1 * time.Hour, + ProcessingTimeout: 30 * time.Second, + FailedRetryBackoff: 5 * time.Second, + MaxStoredResponseLen: 64 * 1024, + ObserveOnly: true, // 默认先观察再强制,避免老客户端立刻中断 + } +} + +type IdempotencyExecuteOptions struct { + Scope string + ActorScope string + Method string + Route string + IdempotencyKey string + Payload any + TTL time.Duration + RequireKey bool +} + +type IdempotencyExecuteResult struct { + Data any + Replayed bool +} + +type IdempotencyCoordinator struct { + repo IdempotencyRepository + cfg IdempotencyConfig +} + +var ( + defaultIdempotencyMu sync.RWMutex + defaultIdempotencySvc *IdempotencyCoordinator +) + +func SetDefaultIdempotencyCoordinator(svc *IdempotencyCoordinator) { + defaultIdempotencyMu.Lock() + defaultIdempotencySvc = svc + defaultIdempotencyMu.Unlock() +} + +func DefaultIdempotencyCoordinator() *IdempotencyCoordinator { + defaultIdempotencyMu.RLock() + defer defaultIdempotencyMu.RUnlock() + return defaultIdempotencySvc +} + +func DefaultWriteIdempotencyTTL() time.Duration { + defaultTTL := DefaultIdempotencyConfig().DefaultTTL + if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.DefaultTTL > 0 { + return coordinator.cfg.DefaultTTL + } + return defaultTTL +} + +func DefaultSystemOperationIdempotencyTTL() time.Duration { + defaultTTL := DefaultIdempotencyConfig().SystemOperationTTL + if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.SystemOperationTTL > 0 { + return coordinator.cfg.SystemOperationTTL + } + return defaultTTL +} + +func NewIdempotencyCoordinator(repo IdempotencyRepository, cfg IdempotencyConfig) *IdempotencyCoordinator { + return &IdempotencyCoordinator{ + repo: repo, + cfg: cfg, + } +} + +func NormalizeIdempotencyKey(raw string) (string, error) { + key := strings.TrimSpace(raw) + if key == "" { + return "", nil + } + if len(key) > 128 { + return "", ErrIdempotencyKeyInvalid + } + for _, r := range key { + if r < 33 || r > 126 { + return "", ErrIdempotencyKeyInvalid + } + } + return key, nil +} + +func HashIdempotencyKey(key string) string { + sum := sha256.Sum256([]byte(key)) + return hex.EncodeToString(sum[:]) +} + +func BuildIdempotencyFingerprint(method, route, actorScope string, payload any) (string, error) { + if method == "" { + method = "POST" + } + if route == "" { + route = "/" + } + if actorScope == "" { + actorScope = "anonymous" + } + + raw, err := json.Marshal(payload) + if err != nil { + return "", ErrIdempotencyInvalidPayload.WithCause(err) + } + sum := sha256.Sum256([]byte( + strings.ToUpper(method) + "\n" + route + "\n" + actorScope + "\n" + string(raw), + )) + return hex.EncodeToString(sum[:]), nil +} + +func RetryAfterSecondsFromError(err error) int { + appErr := new(infraerrors.ApplicationError) + if !errors.As(err, &appErr) || appErr == nil || appErr.Metadata == nil { + return 0 + } + v := strings.TrimSpace(appErr.Metadata["retry_after"]) + if v == "" { + return 0 + } + seconds, convErr := strconv.Atoi(v) + if convErr != nil || seconds <= 0 { + return 0 + } + return seconds +} + +func (c *IdempotencyCoordinator) Execute( + ctx context.Context, + opts IdempotencyExecuteOptions, + execute func(context.Context) (any, error), +) (*IdempotencyExecuteResult, error) { + if execute == nil { + return nil, infraerrors.InternalServer("IDEMPOTENCY_EXECUTOR_NIL", "idempotency executor is nil") + } + + key, err := NormalizeIdempotencyKey(opts.IdempotencyKey) + if err != nil { + return nil, err + } + if key == "" { + if opts.RequireKey && !c.cfg.ObserveOnly { + return nil, ErrIdempotencyKeyRequired + } + data, execErr := execute(ctx) + if execErr != nil { + return nil, execErr + } + return &IdempotencyExecuteResult{Data: data}, nil + } + if c.repo == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "repo_nil") + return nil, ErrIdempotencyStoreUnavail + } + + if opts.Scope == "" { + return nil, infraerrors.BadRequest("IDEMPOTENCY_SCOPE_REQUIRED", "idempotency scope is required") + } + + fingerprint, err := BuildIdempotencyFingerprint(opts.Method, opts.Route, opts.ActorScope, opts.Payload) + if err != nil { + return nil, err + } + + ttl := opts.TTL + if ttl <= 0 { + ttl = c.cfg.DefaultTTL + } + now := time.Now() + expiresAt := now.Add(ttl) + lockedUntil := now.Add(c.cfg.ProcessingTimeout) + keyHash := HashIdempotencyKey(key) + + record := &IdempotencyRecord{ + Scope: opts.Scope, + IdempotencyKeyHash: keyHash, + RequestFingerprint: fingerprint, + Status: IdempotencyStatusProcessing, + LockedUntil: &lockedUntil, + ExpiresAt: expiresAt, + } + + owner, err := c.repo.CreateProcessing(ctx, record) + if err != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "create_processing_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "create_processing", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(err) + } + if owner { + recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "new_claim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "none->processing", false, map[string]string{ + "claim_mode": "new", + }) + } + if !owner { + existing, getErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash) + if getErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "get_existing", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(getErr) + } + if existing == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "missing_existing", + }) + return nil, ErrIdempotencyStoreUnavail + } + if existing.RequestFingerprint != fingerprint { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil) + return nil, ErrIdempotencyKeyConflict + } + reclaimedByExpired := false + if !existing.ExpiresAt.After(now) { + taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, existing.Status, now, lockedUntil, expiresAt) + if reclaimErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_expired_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->store_unavailable", false, map[string]string{ + "operation": "try_reclaim_expired", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr) + } + if taken { + reclaimedByExpired = true + recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "expired_reclaim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->processing", false, map[string]string{ + "claim_mode": "expired_reclaim", + }) + record.ID = existing.ID + } else { + latest, latestErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash) + if latestErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_after_expired_reclaim_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "get_existing_after_expired_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(latestErr) + } + if latest == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing_after_expired_reclaim") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "missing_existing_after_expired_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail + } + if latest.RequestFingerprint != fingerprint { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil) + return nil, ErrIdempotencyKeyConflict + } + existing = latest + } + } + + if !reclaimedByExpired { + switch existing.Status { + case IdempotencyStatusSucceeded: + data, parseErr := c.decodeStoredResponse(existing.ResponseBody) + if parseErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "decode_stored_response_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->store_unavailable", false, map[string]string{ + "operation": "decode_stored_response", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(parseErr) + } + recordIdempotencyReplay(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->replayed", true, nil) + return &IdempotencyExecuteResult{Data: data, Replayed: true}, nil + case IdempotencyStatusProcessing: + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "in_progress"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->conflict", false, nil) + return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now) + case IdempotencyStatusFailedRetryable: + if existing.LockedUntil != nil && existing.LockedUntil.After(now) { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "retry_backoff"}) + recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->retry_backoff_conflict", false, nil) + return nil, c.conflictWithRetryAfter(ErrIdempotencyRetryBackoff, existing.LockedUntil, now) + } + taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, IdempotencyStatusFailedRetryable, now, lockedUntil, expiresAt) + if reclaimErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->store_unavailable", false, map[string]string{ + "operation": "try_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr) + } + if !taken { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "reclaim_race"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->conflict", false, map[string]string{ + "conflict": "reclaim_race", + }) + return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now) + } + recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "reclaim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->processing", false, map[string]string{ + "claim_mode": "reclaim", + }) + record.ID = existing.ID + default: + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "unexpected_status"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->conflict", false, map[string]string{ + "status": existing.Status, + }) + return nil, ErrIdempotencyKeyConflict + } + } + } + + if record.ID == 0 { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "record_id_missing") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "record_id_missing", + }) + return nil, ErrIdempotencyStoreUnavail + } + + execStart := time.Now() + defer func() { + recordIdempotencyProcessingDuration(opts.Route, opts.Scope, time.Since(execStart), nil) + }() + + data, execErr := execute(ctx) + if execErr != nil { + backoffUntil := time.Now().Add(c.cfg.FailedRetryBackoff) + reason := infraerrors.Reason(execErr) + if reason == "" { + reason = "EXECUTION_FAILED" + } + recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->failed_retryable", false, map[string]string{ + "reason": reason, + }) + if markErr := c.repo.MarkFailedRetryable(ctx, record.ID, reason, backoffUntil, expiresAt); markErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_failed_retryable_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "mark_failed_retryable", + }) + } + return nil, execErr + } + + storedBody, marshalErr := c.marshalStoredResponse(data) + if marshalErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "marshal_response_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "marshal_response", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(marshalErr) + } + if markErr := c.repo.MarkSucceeded(ctx, record.ID, 200, storedBody, expiresAt); markErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_succeeded_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "mark_succeeded", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(markErr) + } + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->succeeded", false, nil) + + return &IdempotencyExecuteResult{Data: data}, nil +} + +func (c *IdempotencyCoordinator) conflictWithRetryAfter(base *infraerrors.ApplicationError, lockedUntil *time.Time, now time.Time) error { + if lockedUntil == nil { + return base + } + sec := int(lockedUntil.Sub(now).Seconds()) + if sec <= 0 { + sec = 1 + } + return base.WithMetadata(map[string]string{"retry_after": strconv.Itoa(sec)}) +} + +func (c *IdempotencyCoordinator) marshalStoredResponse(data any) (string, error) { + raw, err := json.Marshal(data) + if err != nil { + return "", err + } + redacted := logredact.RedactText(string(raw)) + if c.cfg.MaxStoredResponseLen > 0 && len(redacted) > c.cfg.MaxStoredResponseLen { + redacted = redacted[:c.cfg.MaxStoredResponseLen] + "...(truncated)" + } + return redacted, nil +} + +func (c *IdempotencyCoordinator) decodeStoredResponse(stored *string) (any, error) { + if stored == nil || strings.TrimSpace(*stored) == "" { + return map[string]any{}, nil + } + var out any + if err := json.Unmarshal([]byte(*stored), &out); err != nil { + return nil, fmt.Errorf("decode stored response: %w", err) + } + return out, nil +} diff --git a/backend/internal/service/idempotency_cleanup_service.go b/backend/internal/service/idempotency_cleanup_service.go new file mode 100644 index 00000000..aaf6949a --- /dev/null +++ b/backend/internal/service/idempotency_cleanup_service.go @@ -0,0 +1,91 @@ +package service + +import ( + "context" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// IdempotencyCleanupService 定期清理已过期的幂等记录,避免表无限增长。 +type IdempotencyCleanupService struct { + repo IdempotencyRepository + interval time.Duration + batch int + + startOnce sync.Once + stopOnce sync.Once + stopCh chan struct{} +} + +func NewIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService { + interval := 60 * time.Second + batch := 500 + if cfg != nil { + if cfg.Idempotency.CleanupIntervalSeconds > 0 { + interval = time.Duration(cfg.Idempotency.CleanupIntervalSeconds) * time.Second + } + if cfg.Idempotency.CleanupBatchSize > 0 { + batch = cfg.Idempotency.CleanupBatchSize + } + } + return &IdempotencyCleanupService{ + repo: repo, + interval: interval, + batch: batch, + stopCh: make(chan struct{}), + } +} + +func (s *IdempotencyCleanupService) Start() { + if s == nil || s.repo == nil { + return + } + s.startOnce.Do(func() { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] started interval=%s batch=%d", s.interval, s.batch) + go s.runLoop() + }) +} + +func (s *IdempotencyCleanupService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] stopped") + }) +} + +func (s *IdempotencyCleanupService) runLoop() { + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + // 启动后先清理一轮,防止重启后积压。 + s.cleanupOnce() + + for { + select { + case <-ticker.C: + s.cleanupOnce() + case <-s.stopCh: + return + } + } +} + +func (s *IdempotencyCleanupService) cleanupOnce() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + deleted, err := s.repo.DeleteExpired(ctx, time.Now(), s.batch) + if err != nil { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleanup failed err=%v", err) + return + } + if deleted > 0 { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleaned expired records count=%d", deleted) + } +} diff --git a/backend/internal/service/idempotency_cleanup_service_test.go b/backend/internal/service/idempotency_cleanup_service_test.go new file mode 100644 index 00000000..556ff364 --- /dev/null +++ b/backend/internal/service/idempotency_cleanup_service_test.go @@ -0,0 +1,69 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type idempotencyCleanupRepoStub struct { + deleteCalls int + lastLimit int + deleteErr error +} + +func (r *idempotencyCleanupRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, nil +} +func (r *idempotencyCleanupRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return nil +} +func (r *idempotencyCleanupRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return nil +} +func (r *idempotencyCleanupRepoStub) DeleteExpired(_ context.Context, _ time.Time, limit int) (int64, error) { + r.deleteCalls++ + r.lastLimit = limit + if r.deleteErr != nil { + return 0, r.deleteErr + } + return 1, nil +} + +func TestNewIdempotencyCleanupService_UsesConfig(t *testing.T) { + repo := &idempotencyCleanupRepoStub{} + cfg := &config.Config{ + Idempotency: config.IdempotencyConfig{ + CleanupIntervalSeconds: 7, + CleanupBatchSize: 321, + }, + } + svc := NewIdempotencyCleanupService(repo, cfg) + require.Equal(t, 7*time.Second, svc.interval) + require.Equal(t, 321, svc.batch) +} + +func TestIdempotencyCleanupService_CleanupOnce(t *testing.T) { + repo := &idempotencyCleanupRepoStub{} + svc := NewIdempotencyCleanupService(repo, &config.Config{ + Idempotency: config.IdempotencyConfig{ + CleanupBatchSize: 99, + }, + }) + + svc.cleanupOnce() + require.Equal(t, 1, repo.deleteCalls) + require.Equal(t, 99, repo.lastLimit) +} diff --git a/backend/internal/service/idempotency_observability.go b/backend/internal/service/idempotency_observability.go new file mode 100644 index 00000000..f1bf2df2 --- /dev/null +++ b/backend/internal/service/idempotency_observability.go @@ -0,0 +1,171 @@ +package service + +import ( + "sort" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// IdempotencyMetricsSnapshot 提供幂等核心指标快照(进程内累计)。 +type IdempotencyMetricsSnapshot struct { + ClaimTotal uint64 `json:"claim_total"` + ReplayTotal uint64 `json:"replay_total"` + ConflictTotal uint64 `json:"conflict_total"` + RetryBackoffTotal uint64 `json:"retry_backoff_total"` + ProcessingDurationCount uint64 `json:"processing_duration_count"` + ProcessingDurationTotalMs float64 `json:"processing_duration_total_ms"` + StoreUnavailableTotal uint64 `json:"store_unavailable_total"` +} + +type idempotencyMetrics struct { + claimTotal atomic.Uint64 + replayTotal atomic.Uint64 + conflictTotal atomic.Uint64 + retryBackoffTotal atomic.Uint64 + processingDurationCount atomic.Uint64 + processingDurationMicros atomic.Uint64 + storeUnavailableTotal atomic.Uint64 +} + +var defaultIdempotencyMetrics idempotencyMetrics + +// GetIdempotencyMetricsSnapshot 返回当前幂等指标快照。 +func GetIdempotencyMetricsSnapshot() IdempotencyMetricsSnapshot { + totalMicros := defaultIdempotencyMetrics.processingDurationMicros.Load() + return IdempotencyMetricsSnapshot{ + ClaimTotal: defaultIdempotencyMetrics.claimTotal.Load(), + ReplayTotal: defaultIdempotencyMetrics.replayTotal.Load(), + ConflictTotal: defaultIdempotencyMetrics.conflictTotal.Load(), + RetryBackoffTotal: defaultIdempotencyMetrics.retryBackoffTotal.Load(), + ProcessingDurationCount: defaultIdempotencyMetrics.processingDurationCount.Load(), + ProcessingDurationTotalMs: float64(totalMicros) / 1000.0, + StoreUnavailableTotal: defaultIdempotencyMetrics.storeUnavailableTotal.Load(), + } +} + +func recordIdempotencyClaim(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.claimTotal.Add(1) + logIdempotencyMetric("idempotency_claim_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyReplay(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.replayTotal.Add(1) + logIdempotencyMetric("idempotency_replay_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyConflict(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.conflictTotal.Add(1) + logIdempotencyMetric("idempotency_conflict_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyRetryBackoff(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.retryBackoffTotal.Add(1) + logIdempotencyMetric("idempotency_retry_backoff_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyProcessingDuration(endpoint, scope string, duration time.Duration, attrs map[string]string) { + if duration < 0 { + duration = 0 + } + defaultIdempotencyMetrics.processingDurationCount.Add(1) + defaultIdempotencyMetrics.processingDurationMicros.Add(uint64(duration.Microseconds())) + logIdempotencyMetric("idempotency_processing_duration_ms", endpoint, scope, strconv.FormatFloat(duration.Seconds()*1000, 'f', 3, 64), attrs) +} + +// RecordIdempotencyStoreUnavailable 记录幂等存储不可用事件(用于降级路径观测)。 +func RecordIdempotencyStoreUnavailable(endpoint, scope, strategy string) { + defaultIdempotencyMetrics.storeUnavailableTotal.Add(1) + attrs := map[string]string{} + if strategy != "" { + attrs["strategy"] = strategy + } + logIdempotencyMetric("idempotency_store_unavailable_total", endpoint, scope, "1", attrs) +} + +func logIdempotencyAudit(endpoint, scope, keyHash, stateTransition string, replayed bool, attrs map[string]string) { + var b strings.Builder + builderWriteString(&b, "[IdempotencyAudit]") + builderWriteString(&b, " endpoint=") + builderWriteString(&b, safeAuditField(endpoint)) + builderWriteString(&b, " scope=") + builderWriteString(&b, safeAuditField(scope)) + builderWriteString(&b, " key_hash=") + builderWriteString(&b, safeAuditField(keyHash)) + builderWriteString(&b, " state_transition=") + builderWriteString(&b, safeAuditField(stateTransition)) + builderWriteString(&b, " replayed=") + builderWriteString(&b, strconv.FormatBool(replayed)) + if len(attrs) > 0 { + appendSortedAttrs(&b, attrs) + } + logger.LegacyPrintf("service.idempotency", "%s", b.String()) +} + +func logIdempotencyMetric(name, endpoint, scope, value string, attrs map[string]string) { + var b strings.Builder + builderWriteString(&b, "[IdempotencyMetric]") + builderWriteString(&b, " name=") + builderWriteString(&b, safeAuditField(name)) + builderWriteString(&b, " endpoint=") + builderWriteString(&b, safeAuditField(endpoint)) + builderWriteString(&b, " scope=") + builderWriteString(&b, safeAuditField(scope)) + builderWriteString(&b, " value=") + builderWriteString(&b, safeAuditField(value)) + if len(attrs) > 0 { + appendSortedAttrs(&b, attrs) + } + logger.LegacyPrintf("service.idempotency", "%s", b.String()) +} + +func appendSortedAttrs(builder *strings.Builder, attrs map[string]string) { + if len(attrs) == 0 { + return + } + keys := make([]string, 0, len(attrs)) + for k := range attrs { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + builderWriteByte(builder, ' ') + builderWriteString(builder, k) + builderWriteByte(builder, '=') + builderWriteString(builder, safeAuditField(attrs[k])) + } +} + +func safeAuditField(v string) string { + value := strings.TrimSpace(v) + if value == "" { + return "-" + } + // 日志按 key=value 输出,替换空白避免解析歧义。 + value = strings.ReplaceAll(value, "\n", "_") + value = strings.ReplaceAll(value, "\r", "_") + value = strings.ReplaceAll(value, "\t", "_") + value = strings.ReplaceAll(value, " ", "_") + return value +} + +func resetIdempotencyMetricsForTest() { + defaultIdempotencyMetrics.claimTotal.Store(0) + defaultIdempotencyMetrics.replayTotal.Store(0) + defaultIdempotencyMetrics.conflictTotal.Store(0) + defaultIdempotencyMetrics.retryBackoffTotal.Store(0) + defaultIdempotencyMetrics.processingDurationCount.Store(0) + defaultIdempotencyMetrics.processingDurationMicros.Store(0) + defaultIdempotencyMetrics.storeUnavailableTotal.Store(0) +} + +func builderWriteString(builder *strings.Builder, value string) { + _, _ = builder.WriteString(value) +} + +func builderWriteByte(builder *strings.Builder, value byte) { + _ = builder.WriteByte(value) +} diff --git a/backend/internal/service/idempotency_test.go b/backend/internal/service/idempotency_test.go new file mode 100644 index 00000000..6ff75d1c --- /dev/null +++ b/backend/internal/service/idempotency_test.go @@ -0,0 +1,805 @@ +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type inMemoryIdempotencyRepo struct { + mu sync.Mutex + nextID int64 + data map[string]*IdempotencyRecord +} + +func newInMemoryIdempotencyRepo() *inMemoryIdempotencyRepo { + return &inMemoryIdempotencyRepo{ + nextID: 1, + data: make(map[string]*IdempotencyRecord), + } +} + +func (r *inMemoryIdempotencyRepo) key(scope, hash string) string { + return scope + "|" + hash +} + +func cloneRecord(in *IdempotencyRecord) *IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + if in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + return &out +} + +func (r *inMemoryIdempotencyRepo) CreateProcessing(_ context.Context, record *IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + rec := cloneRecord(record) + rec.ID = r.nextID + rec.CreatedAt = time.Now() + rec.UpdatedAt = rec.CreatedAt + r.nextID++ + r.data[k] = rec + record.ID = rec.ID + record.CreatedAt = rec.CreatedAt + record.UpdatedAt = rec.UpdatedAt + return true, nil +} + +func (r *inMemoryIdempotencyRepo) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return cloneRecord(r.data[r.key(scope, keyHash)]), nil +} + +func (r *inMemoryIdempotencyRepo) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + rec.UpdatedAt = time.Now() + return true, nil + } + return false, nil +} + +func (r *inMemoryIdempotencyRepo) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.UpdatedAt = time.Now() + return true, nil + } + return false, nil +} + +func (r *inMemoryIdempotencyRepo) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.UpdatedAt = time.Now() + rec.ErrorReason = nil + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + return nil + } + return errors.New("record not found") +} + +func (r *inMemoryIdempotencyRepo) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.UpdatedAt = time.Now() + rec.ErrorReason = &errorReason + return nil + } + return errors.New("record not found") +} + +func (r *inMemoryIdempotencyRepo) DeleteExpired(_ context.Context, now time.Time, _ int) (int64, error) { + r.mu.Lock() + defer r.mu.Unlock() + var deleted int64 + for k, rec := range r.data { + if !rec.ExpiresAt.After(now) { + delete(r.data, k) + deleted++ + } + } + return deleted, nil +} + +func TestIdempotencyCoordinator_RequireKey(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.ObserveOnly = false + coordinator := NewIdempotencyCoordinator(repo, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "admin:1", + RequireKey: true, + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyRequired)) +} + +func TestIdempotencyCoordinator_ReplaySucceededResult(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(repo, cfg) + + execCount := 0 + exec := func(ctx context.Context) (any, error) { + execCount++ + return map[string]any{"count": execCount}, nil + } + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-1", + Payload: map[string]any{"a": 1}, + } + + first, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.False(t, first.Replayed) + + second, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.True(t, second.Replayed) + require.Equal(t, 1, execCount, "second request should replay without executing business logic") + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), metrics.ClaimTotal) + require.Equal(t, uint64(1), metrics.ReplayTotal) +} + +func TestIdempotencyCoordinator_ReclaimExpiredSucceededRecord(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope.expired", + Method: "POST", + Route: "/test/expired", + ActorScope: "user:99", + RequireKey: true, + IdempotencyKey: "expired-case", + Payload: map[string]any{"k": "v"}, + } + + execCount := 0 + exec := func(ctx context.Context) (any, error) { + execCount++ + return map[string]any{"count": execCount}, nil + } + + first, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, first) + require.False(t, first.Replayed) + require.Equal(t, 1, execCount) + + keyHash := HashIdempotencyKey(opts.IdempotencyKey) + repo.mu.Lock() + existing := repo.data[repo.key(opts.Scope, keyHash)] + require.NotNil(t, existing) + existing.ExpiresAt = time.Now().Add(-time.Second) + repo.mu.Unlock() + + second, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, second) + require.False(t, second.Replayed, "expired record should be reclaimed and execute business logic again") + require.Equal(t, 2, execCount) + + third, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, third) + require.True(t, third.Replayed) + payload, ok := third.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, float64(2), payload["count"]) + + metrics := GetIdempotencyMetricsSnapshot() + require.GreaterOrEqual(t, metrics.ClaimTotal, uint64(2)) + require.GreaterOrEqual(t, metrics.ReplayTotal, uint64(1)) +} + +func TestIdempotencyCoordinator_SameKeyDifferentPayloadConflict(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(repo, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-2", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-2", + Payload: map[string]any{"a": 2}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyConflict)) + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), metrics.ConflictTotal) +} + +func TestIdempotencyCoordinator_BackoffAfterRetryableFailure(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.FailedRetryBackoff = 2 * time.Second + coordinator := NewIdempotencyCoordinator(repo, cfg) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-3", + Payload: map[string]any{"a": 1}, + } + + _, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + return nil, infraerrors.InternalServer("UPSTREAM_ERROR", "upstream error") + }) + require.Error(t, err) + + _, err = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyRetryBackoff)) + require.Greater(t, RetryAfterSecondsFromError(err), 0) + + metrics := GetIdempotencyMetricsSnapshot() + require.GreaterOrEqual(t, metrics.RetryBackoffTotal, uint64(2)) + require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1)) + require.GreaterOrEqual(t, metrics.ProcessingDurationCount, uint64(1)) +} + +func TestIdempotencyCoordinator_ConcurrentSameKeySingleSideEffect(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + coordinator := NewIdempotencyCoordinator(repo, cfg) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope.concurrent", + Method: "POST", + Route: "/test/concurrent", + ActorScope: "user:7", + RequireKey: true, + IdempotencyKey: "concurrent-case", + Payload: map[string]any{"v": 1}, + } + + var execCount int32 + var wg sync.WaitGroup + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + atomic.AddInt32(&execCount, 1) + time.Sleep(80 * time.Millisecond) + return map[string]any{"ok": true}, nil + }) + }() + } + wg.Wait() + + replayed, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + atomic.AddInt32(&execCount, 1) + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + require.True(t, replayed.Replayed) + require.Equal(t, int32(1), atomic.LoadInt32(&execCount), "concurrent same-key requests should execute business side-effect once") + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), metrics.ClaimTotal) + require.Equal(t, uint64(1), metrics.ReplayTotal) + require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1)) +} + +type failingIdempotencyRepo struct{} + +func (failingIdempotencyRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (failingIdempotencyRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return errors.New("store unavailable") +} +func (failingIdempotencyRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (failingIdempotencyRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +func TestIdempotencyCoordinator_StoreUnavailableMetrics(t *testing.T) { + resetIdempotencyMetricsForTest() + coordinator := NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig()) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope.unavailable", + Method: "POST", + Route: "/test/unavailable", + ActorScope: "admin:1", + RequireKey: true, + IdempotencyKey: "case-unavailable", + Payload: map[string]any{"v": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + require.GreaterOrEqual(t, GetIdempotencyMetricsSnapshot().StoreUnavailableTotal, uint64(1)) +} + +func TestDefaultIdempotencyCoordinatorAndTTLs(t *testing.T) { + SetDefaultIdempotencyCoordinator(nil) + require.Nil(t, DefaultIdempotencyCoordinator()) + require.Equal(t, DefaultIdempotencyConfig().DefaultTTL, DefaultWriteIdempotencyTTL()) + require.Equal(t, DefaultIdempotencyConfig().SystemOperationTTL, DefaultSystemOperationIdempotencyTTL()) + + coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{ + DefaultTTL: 2 * time.Hour, + SystemOperationTTL: 15 * time.Minute, + ProcessingTimeout: 10 * time.Second, + FailedRetryBackoff: 3 * time.Second, + ObserveOnly: false, + }) + SetDefaultIdempotencyCoordinator(coordinator) + t.Cleanup(func() { + SetDefaultIdempotencyCoordinator(nil) + }) + + require.Same(t, coordinator, DefaultIdempotencyCoordinator()) + require.Equal(t, 2*time.Hour, DefaultWriteIdempotencyTTL()) + require.Equal(t, 15*time.Minute, DefaultSystemOperationIdempotencyTTL()) +} + +func TestNormalizeIdempotencyKeyAndFingerprint(t *testing.T) { + key, err := NormalizeIdempotencyKey(" abc-123 ") + require.NoError(t, err) + require.Equal(t, "abc-123", key) + + key, err = NormalizeIdempotencyKey("") + require.NoError(t, err) + require.Equal(t, "", key) + + _, err = NormalizeIdempotencyKey(string(make([]byte, 129))) + require.Error(t, err) + + _, err = NormalizeIdempotencyKey("bad\nkey") + require.Error(t, err) + + fp1, err := BuildIdempotencyFingerprint("", "", "", map[string]any{"a": 1}) + require.NoError(t, err) + require.NotEmpty(t, fp1) + fp2, err := BuildIdempotencyFingerprint("POST", "/", "anonymous", map[string]any{"a": 1}) + require.NoError(t, err) + require.Equal(t, fp1, fp2) + + _, err = BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"bad": make(chan int)}) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyInvalidPayload), infraerrors.Code(err)) +} + +func TestRetryAfterSecondsFromErrorBranches(t *testing.T) { + require.Equal(t, 0, RetryAfterSecondsFromError(nil)) + require.Equal(t, 0, RetryAfterSecondsFromError(errors.New("plain"))) + + err := ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "12"}) + require.Equal(t, 12, RetryAfterSecondsFromError(err)) + + err = ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "bad"}) + require.Equal(t, 0, RetryAfterSecondsFromError(err)) +} + +func TestIdempotencyCoordinator_ExecuteNilExecutorAndNoKeyPassThrough(t *testing.T) { + repo := newInMemoryIdempotencyRepo() + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Payload: map[string]any{"a": 1}, + }, nil) + require.Error(t, err) + require.Equal(t, "IDEMPOTENCY_EXECUTOR_NIL", infraerrors.Reason(err)) + + called := 0 + result, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + RequireKey: true, + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + called++ + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + require.Equal(t, 1, called) + require.NotNil(t, result) + require.False(t, result.Replayed) +} + +type noIDOwnerRepo struct{} + +func (noIDOwnerRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return true, nil +} +func (noIDOwnerRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, nil +} +func (noIDOwnerRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, nil +} +func (noIDOwnerRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (noIDOwnerRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { return nil } +func (noIDOwnerRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return nil +} +func (noIDOwnerRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { return 0, nil } + +func TestIdempotencyCoordinator_RepoNilScopeRequiredAndRecordIDMissing(t *testing.T) { + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(nil, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + coordinator = NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), cfg) + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + IdempotencyKey: "k2", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, "IDEMPOTENCY_SCOPE_REQUIRED", infraerrors.Reason(err)) + + coordinator = NewIdempotencyCoordinator(noIDOwnerRepo{}, cfg) + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-no-id", + IdempotencyKey: "k3", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) +} + +type conflictBranchRepo struct { + existing *IdempotencyRecord + tryReclaimErr error + tryReclaimOK bool +} + +func (r *conflictBranchRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, nil +} +func (r *conflictBranchRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return cloneRecord(r.existing), nil +} +func (r *conflictBranchRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + if r.tryReclaimErr != nil { + return false, r.tryReclaimErr + } + return r.tryReclaimOK, nil +} +func (r *conflictBranchRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *conflictBranchRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return nil +} +func (r *conflictBranchRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return nil +} +func (r *conflictBranchRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, nil +} + +func TestIdempotencyCoordinator_ConflictBranchesAndDecodeError(t *testing.T) { + now := time.Now() + fp, err := BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"a": 1}) + require.NoError(t, err) + badBody := "{bad-json" + repo := &conflictBranchRepo{ + existing: &IdempotencyRecord{ + ID: 1, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: IdempotencyStatusSucceeded, + ResponseBody: &badBody, + ExpiresAt: now.Add(time.Hour), + }, + } + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.existing = &IdempotencyRecord{ + ID: 2, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: "unknown", + ExpiresAt: now.Add(time.Hour), + } + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyKeyConflict), infraerrors.Code(err)) + + repo.existing = &IdempotencyRecord{ + ID: 3, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: IdempotencyStatusFailedRetryable, + LockedUntil: ptrTime(now.Add(-time.Second)), + ExpiresAt: now.Add(time.Hour), + } + repo.tryReclaimErr = errors.New("reclaim down") + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.tryReclaimErr = nil + repo.tryReclaimOK = false + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyInProgress), infraerrors.Code(err)) +} + +type markBehaviorRepo struct { + inMemoryIdempotencyRepo + failMarkSucceeded bool + failMarkFailed bool +} + +func (r *markBehaviorRepo) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + if r.failMarkSucceeded { + return errors.New("mark succeeded failed") + } + return r.inMemoryIdempotencyRepo.MarkSucceeded(ctx, id, responseStatus, responseBody, expiresAt) +} + +func (r *markBehaviorRepo) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + if r.failMarkFailed { + return errors.New("mark failed retryable failed") + } + return r.inMemoryIdempotencyRepo.MarkFailedRetryable(ctx, id, errorReason, lockedUntil, expiresAt) +} + +func TestIdempotencyCoordinator_MarkAndMarshalBranches(t *testing.T) { + repo := &markBehaviorRepo{inMemoryIdempotencyRepo: *newInMemoryIdempotencyRepo()} + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + repo.failMarkSucceeded = true + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-success", + IdempotencyKey: "k1", + Method: "POST", + Route: "/ok", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.failMarkSucceeded = false + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-marshal", + IdempotencyKey: "k2", + Method: "POST", + Route: "/bad", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"bad": make(chan int)}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.failMarkFailed = true + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-fail", + IdempotencyKey: "k3", + Method: "POST", + Route: "/fail", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return nil, errors.New("plain failure") + }) + require.Error(t, err) + require.Equal(t, "plain failure", err.Error()) +} + +func TestIdempotencyCoordinator_HelperBranches(t *testing.T) { + c := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{ + DefaultTTL: time.Hour, + SystemOperationTTL: time.Hour, + ProcessingTimeout: time.Second, + FailedRetryBackoff: time.Second, + MaxStoredResponseLen: 12, + ObserveOnly: false, + }) + + // conflictWithRetryAfter without locked_until should return base error. + base := ErrIdempotencyInProgress + err := c.conflictWithRetryAfter(base, nil, time.Now()) + require.Equal(t, infraerrors.Code(base), infraerrors.Code(err)) + + // marshalStoredResponse should truncate. + body, err := c.marshalStoredResponse(map[string]any{"long": "abcdefghijklmnopqrstuvwxyz"}) + require.NoError(t, err) + require.Contains(t, body, "...(truncated)") + + // decodeStoredResponse empty and invalid json. + out, err := c.decodeStoredResponse(nil) + require.NoError(t, err) + _, ok := out.(map[string]any) + require.True(t, ok) + + invalid := "{invalid" + _, err = c.decodeStoredResponse(&invalid) + require.Error(t, err) +} diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 261da0ef..f3130c91 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -7,13 +7,14 @@ import ( "encoding/hex" "encoding/json" "fmt" - "log" "log/slog" "net/http" "regexp" "strconv" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // 预编译正则表达式(避免每次调用重新编译) @@ -45,6 +46,7 @@ type Fingerprint struct { StainlessArch string StainlessRuntime string StainlessRuntimeVersion string + UpdatedAt int64 `json:",omitempty"` // Unix timestamp,用于判断是否需要续期TTL } // IdentityCache defines cache operations for identity service @@ -77,14 +79,26 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID // 尝试从缓存获取指纹 cached, err := s.cache.GetFingerprint(ctx, accountID) if err == nil && cached != nil { + needWrite := false + // 检查客户端的user-agent是否是更新版本 clientUA := headers.Get("User-Agent") if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) { - // 更新user-agent - cached.UserAgent = clientUA - // 保存更新后的指纹 - _ = s.cache.SetFingerprint(ctx, accountID, cached) - log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA) + // 版本升级:merge 语义 — 仅更新请求中实际携带的字段,保留缓存值 + // 避免缺失的头被硬编码默认值覆盖(如新 CLI 版本 + 旧 SDK 默认值的不一致) + mergeHeadersIntoFingerprint(cached, headers) + needWrite = true + logger.LegacyPrintf("service.identity", "Updated fingerprint for account %d: %s (merge update)", accountID, clientUA) + } else if time.Since(time.Unix(cached.UpdatedAt, 0)) > 24*time.Hour { + // 距上次写入超过24小时,续期TTL + needWrite = true + } + + if needWrite { + cached.UpdatedAt = time.Now().Unix() + if err := s.cache.SetFingerprint(ctx, accountID, cached); err != nil { + logger.LegacyPrintf("service.identity", "Warning: failed to refresh fingerprint for account %d: %v", accountID, err) + } } return cached, nil } @@ -94,13 +108,14 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID // 生成随机ClientID fp.ClientID = generateClientID() + fp.UpdatedAt = time.Now().Unix() - // 保存到缓存(永不过期) + // 保存到缓存(7天TTL,每24小时自动续期) if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil { - log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err) + logger.LegacyPrintf("service.identity", "Warning: failed to cache fingerprint for account %d: %v", accountID, err) } - log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID) + logger.LegacyPrintf("service.identity", "Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID) return fp, nil } @@ -126,6 +141,31 @@ func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fin return fp } +// mergeHeadersIntoFingerprint 将请求头中实际存在的字段合并到现有指纹中(用于版本升级场景) +// 关键语义:请求中有的字段 → 用新值覆盖;缺失的头 → 保留缓存中的已有值 +// 与 createFingerprintFromHeaders 的区别:后者用于首次创建,缺失头回退到 defaultFingerprint; +// 本函数用于升级更新,缺失头保留缓存值,避免将已知的真实值退化为硬编码默认值 +func mergeHeadersIntoFingerprint(fp *Fingerprint, headers http.Header) { + // User-Agent:版本升级的触发条件,一定存在 + if ua := headers.Get("User-Agent"); ua != "" { + fp.UserAgent = ua + } + // X-Stainless-* 头:仅在请求中实际携带时才更新,否则保留缓存值 + mergeHeader(headers, "X-Stainless-Lang", &fp.StainlessLang) + mergeHeader(headers, "X-Stainless-Package-Version", &fp.StainlessPackageVersion) + mergeHeader(headers, "X-Stainless-OS", &fp.StainlessOS) + mergeHeader(headers, "X-Stainless-Arch", &fp.StainlessArch) + mergeHeader(headers, "X-Stainless-Runtime", &fp.StainlessRuntime) + mergeHeader(headers, "X-Stainless-Runtime-Version", &fp.StainlessRuntimeVersion) +} + +// mergeHeader 如果请求头中存在该字段则更新目标值,否则保留原值 +func mergeHeader(headers http.Header, key string, target *string) { + if v := headers.Get(key); v != "" { + *target = v + } +} + // getHeaderOrDefault 获取header值,如果不存在则返回默认值 func getHeaderOrDefault(headers http.Header, key, defaultValue string) string { if v := headers.Get(key); v != "" { @@ -277,19 +317,19 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b // 获取或生成固定的伪装 session ID maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID) if err != nil { - log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.identity", "Warning: failed to get masked session ID for account %d: %v", account.ID, err) return newBody, nil } if maskedSessionID == "" { // 首次或已过期,生成新的伪装 session ID maskedSessionID = generateRandomUUID() - log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID) + logger.LegacyPrintf("service.identity", "Generated new masked session ID for account %d: %s", account.ID, maskedSessionID) } // 刷新 TTL(每次请求都刷新,保持 15 分钟有效期) if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil { - log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err) } // 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容 @@ -335,7 +375,7 @@ func generateClientID() string { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { // 极罕见的情况,使用时间戳+固定值作为fallback - log.Printf("Warning: crypto/rand.Read failed: %v, using fallback", err) + logger.LegacyPrintf("service.identity", "Warning: crypto/rand.Read failed: %v, using fallback", err) // 使用SHA256(当前纳秒时间)作为fallback h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano()))) return hex.EncodeToString(h[:]) @@ -370,8 +410,25 @@ func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) { return major, minor, patch, true } +// extractProduct 提取 User-Agent 中 "/" 前的产品名 +// 例如:claude-cli/2.1.22 (external, cli) -> "claude-cli" +func extractProduct(ua string) string { + if idx := strings.Index(ua, "/"); idx > 0 { + return strings.ToLower(ua[:idx]) + } + return "" +} + // isNewerVersion 比较版本号,判断newUA是否比cachedUA更新 +// 要求产品名一致(防止浏览器 UA 如 Mozilla/5.0 误判为更新版本) func isNewerVersion(newUA, cachedUA string) bool { + // 校验产品名一致性 + newProduct := extractProduct(newUA) + cachedProduct := extractProduct(cachedUA) + if newProduct == "" || cachedProduct == "" || newProduct != cachedProduct { + return false + } + newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA) cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA) diff --git a/backend/internal/service/model_rate_limit.go b/backend/internal/service/model_rate_limit.go index ff4b5977..c45615cc 100644 --- a/backend/internal/service/model_rate_limit.go +++ b/backend/internal/service/model_rate_limit.go @@ -4,8 +4,6 @@ import ( "context" "strings" "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" ) const modelRateLimitsKey = "model_rate_limits" @@ -73,7 +71,7 @@ func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requ return "" } // thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking) - if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + if enabled, ok := ThinkingEnabledFromContext(ctx); ok { modelKey = applyThinkingModelSuffix(modelKey, enabled) } return modelKey diff --git a/backend/internal/service/model_rate_limit_test.go b/backend/internal/service/model_rate_limit_test.go index a51e6909..b79b9688 100644 --- a/backend/internal/service/model_rate_limit_test.go +++ b/backend/internal/service/model_rate_limit_test.go @@ -318,110 +318,6 @@ func TestGetModelRateLimitRemainingTime(t *testing.T) { } } -func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) { - now := time.Now() - future10m := now.Add(10 * time.Minute).Format(time.RFC3339) - past := now.Add(-10 * time.Minute).Format(time.RFC3339) - - tests := []struct { - name string - account *Account - requestedModel string - minExpected time.Duration - maxExpected time.Duration - }{ - { - name: "nil account", - account: nil, - requestedModel: "claude-sonnet-4-5", - minExpected: 0, - maxExpected: 0, - }, - { - name: "non-antigravity platform", - account: &Account{ - Platform: PlatformAnthropic, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future10m, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 0, - maxExpected: 0, - }, - { - name: "claude scope rate limited", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future10m, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 9 * time.Minute, - maxExpected: 11 * time.Minute, - }, - { - name: "gemini_text scope rate limited", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "gemini_text": map[string]any{ - "rate_limit_reset_at": future10m, - }, - }, - }, - }, - requestedModel: "gemini-3-flash", - minExpected: 9 * time.Minute, - maxExpected: 11 * time.Minute, - }, - { - name: "expired scope rate limit", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": past, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 0, - maxExpected: 0, - }, - { - name: "unsupported model", - account: &Account{ - Platform: PlatformAntigravity, - }, - requestedModel: "gpt-4", - minExpected: 0, - maxExpected: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel) - if result < tt.minExpected || result > tt.maxExpected { - t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected) - } - }) - } -} - func TestGetRateLimitRemainingTime(t *testing.T) { now := time.Now() future15m := now.Add(15 * time.Minute).Format(time.RFC3339) @@ -442,45 +338,19 @@ func TestGetRateLimitRemainingTime(t *testing.T) { maxExpected: 0, }, { - name: "model remaining > scope remaining - returns model", + name: "model rate limited - 15 minutes", account: &Account{ Platform: PlatformAntigravity, Extra: map[string]any{ modelRateLimitsKey: map[string]any{ "claude-sonnet-4-5": map[string]any{ - "rate_limit_reset_at": future15m, // 15 分钟 - }, - }, - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future5m, // 5 分钟 + "rate_limit_reset_at": future15m, }, }, }, }, requestedModel: "claude-sonnet-4-5", - minExpected: 14 * time.Minute, // 应返回较大的 15 分钟 - maxExpected: 16 * time.Minute, - }, - { - name: "scope remaining > model remaining - returns scope", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - modelRateLimitsKey: map[string]any{ - "claude-sonnet-4-5": map[string]any{ - "rate_limit_reset_at": future5m, // 5 分钟 - }, - }, - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future15m, // 15 分钟 - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 14 * time.Minute, // 应返回较大的 15 分钟 + minExpected: 14 * time.Minute, maxExpected: 16 * time.Minute, }, { @@ -499,22 +369,6 @@ func TestGetRateLimitRemainingTime(t *testing.T) { minExpected: 4 * time.Minute, maxExpected: 6 * time.Minute, }, - { - name: "only scope rate limited", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future5m, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 4 * time.Minute, - maxExpected: 6 * time.Minute, - }, { name: "neither rate limited", account: &Account{ diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index 15543080..0931f9ce 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -12,8 +12,9 @@ import ( // OpenAIOAuthClient interface for OpenAI OAuth operations type OpenAIOAuthClient interface { - ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) + ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) + RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) } // ClaudeOAuthClient handles HTTP requests for Claude OAuth flows @@ -217,7 +218,7 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( // Ensure org_uuid is set (from step 1 if not from token response) if tokenInfo.OrgUUID == "" && orgUUID != "" { tokenInfo.OrgUUID = orgUUID - log.Printf("[OAuth] Set org_uuid from cookie auth: %s", orgUUID) + log.Printf("[OAuth] Set org_uuid from cookie auth") } return tokenInfo, nil @@ -251,16 +252,16 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" { tokenInfo.OrgUUID = tokenResp.Organization.UUID - log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID) + log.Printf("[OAuth] Got org_uuid") } if tokenResp.Account != nil { if tokenResp.Account.UUID != "" { tokenInfo.AccountUUID = tokenResp.Account.UUID - log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID) + log.Printf("[OAuth] Got account_uuid") } if tokenResp.Account.EmailAddress != "" { tokenInfo.EmailAddress = tokenResp.Account.EmailAddress - log.Printf("[OAuth] Got email_address: %s", tokenInfo.EmailAddress) + log.Printf("[OAuth] Got email_address") } } diff --git a/backend/internal/service/oauth_service_test.go b/backend/internal/service/oauth_service_test.go new file mode 100644 index 00000000..78f39dc5 --- /dev/null +++ b/backend/internal/service/oauth_service_test.go @@ -0,0 +1,607 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +// --- mock: ClaudeOAuthClient --- + +type mockClaudeOAuthClient struct { + getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error) + getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) + exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) + refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) +} + +func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { + if m.getOrgUUIDFunc != nil { + return m.getOrgUUIDFunc(ctx, sessionKey, proxyURL) + } + panic("GetOrganizationUUID not implemented") +} + +func (m *mockClaudeOAuthClient) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { + if m.getAuthCodeFunc != nil { + return m.getAuthCodeFunc(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL) + } + panic("GetAuthorizationCode not implemented") +} + +func (m *mockClaudeOAuthClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + if m.exchangeCodeFunc != nil { + return m.exchangeCodeFunc(ctx, code, codeVerifier, state, proxyURL, isSetupToken) + } + panic("ExchangeCodeForToken not implemented") +} + +func (m *mockClaudeOAuthClient) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if m.refreshTokenFunc != nil { + return m.refreshTokenFunc(ctx, refreshToken, proxyURL) + } + panic("RefreshToken not implemented") +} + +// --- mock: ProxyRepository (最小实现,仅覆盖 OAuthService 依赖的方法) --- + +type mockProxyRepoForOAuth struct { + getByIDFunc func(ctx context.Context, id int64) (*Proxy, error) +} + +func (m *mockProxyRepoForOAuth) Create(ctx context.Context, proxy *Proxy) error { + panic("Create not implemented") +} +func (m *mockProxyRepoForOAuth) GetByID(ctx context.Context, id int64) (*Proxy, error) { + if m.getByIDFunc != nil { + return m.getByIDFunc(ctx, id) + } + return nil, fmt.Errorf("proxy not found") +} +func (m *mockProxyRepoForOAuth) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { + panic("ListByIDs not implemented") +} +func (m *mockProxyRepoForOAuth) Update(ctx context.Context, proxy *Proxy) error { + panic("Update not implemented") +} +func (m *mockProxyRepoForOAuth) Delete(ctx context.Context, id int64) error { + panic("Delete not implemented") +} +func (m *mockProxyRepoForOAuth) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) { + panic("List not implemented") +} +func (m *mockProxyRepoForOAuth) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) { + panic("ListWithFilters not implemented") +} +func (m *mockProxyRepoForOAuth) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + panic("ListWithFiltersAndAccountCount not implemented") +} +func (m *mockProxyRepoForOAuth) ListActive(ctx context.Context) ([]Proxy, error) { + panic("ListActive not implemented") +} +func (m *mockProxyRepoForOAuth) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { + panic("ListActiveWithAccountCount not implemented") +} +func (m *mockProxyRepoForOAuth) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { + panic("ExistsByHostPortAuth not implemented") +} +func (m *mockProxyRepoForOAuth) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { + panic("CountAccountsByProxyID not implemented") +} +func (m *mockProxyRepoForOAuth) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) { + panic("ListAccountSummariesByProxyID not implemented") +} + +// ===================== +// 测试用例 +// ===================== + +func TestNewOAuthService(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{} + client := &mockClaudeOAuthClient{} + svc := NewOAuthService(proxyRepo, client) + + if svc == nil { + t.Fatal("NewOAuthService 返回 nil") + } + if svc.proxyRepo != proxyRepo { + t.Fatal("proxyRepo 未正确设置") + } + if svc.oauthClient != client { + t.Fatal("oauthClient 未正确设置") + } + if svc.sessionStore == nil { + t.Fatal("sessionStore 应被自动初始化") + } + + // 清理 + svc.Stop() +} + +func TestOAuthService_GenerateAuthURL(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + result, err := svc.GenerateAuthURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + if result == nil { + t.Fatal("GenerateAuthURL 返回 nil") + } + if result.AuthURL == "" { + t.Fatal("AuthURL 为空") + } + if result.SessionID == "" { + t.Fatal("SessionID 为空") + } + + // 验证 session 已存储 + session, ok := svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.Scope != oauth.ScopeOAuth { + t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeOAuth) + } +} + +func TestOAuthService_GenerateAuthURL_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + ID: 1, + Protocol: "http", + Host: "proxy.example.com", + Port: 8080, + }, nil + }, + } + svc := NewOAuthService(proxyRepo, &mockClaudeOAuthClient{}) + defer svc.Stop() + + proxyID := int64(1) + result, err := svc.GenerateAuthURL(context.Background(), &proxyID) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + + session, ok := svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.ProxyURL != "http://proxy.example.com:8080" { + t.Fatalf("ProxyURL 不匹配: got=%q", session.ProxyURL) + } +} + +func TestOAuthService_GenerateSetupTokenURL(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + result, err := svc.GenerateSetupTokenURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err) + } + if result == nil { + t.Fatal("GenerateSetupTokenURL 返回 nil") + } + + // 验证 scope 是 inference + session, ok := svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.Scope != oauth.ScopeInference { + t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeInference) + } +} + +func TestOAuthService_ExchangeCode_SessionNotFound(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + _, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: "nonexistent-session", + Code: "test-code", + }) + if err == nil { + t.Fatal("ExchangeCode 应返回错误(session 不存在)") + } + if err.Error() != "session not found or expired" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_ExchangeCode_Success(t *testing.T) { + t.Parallel() + + exchangeCalled := false + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + exchangeCalled = true + if code != "auth-code-123" { + t.Errorf("code 不匹配: got=%q", code) + } + if isSetupToken { + t.Error("isSetupToken 应为 false(ScopeOAuth)") + } + return &oauth.TokenResponse{ + AccessToken: "access-token-abc", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "refresh-token-xyz", + Scope: oauth.ScopeOAuth, + Organization: &oauth.OrgInfo{UUID: "org-uuid-111"}, + Account: &oauth.AccountInfo{UUID: "acc-uuid-222", EmailAddress: "test@example.com"}, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + // 先生成 URL 以创建 session + result, err := svc.GenerateAuthURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + + // 交换 code + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "auth-code-123", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + + if !exchangeCalled { + t.Fatal("ExchangeCodeForToken 未被调用") + } + if tokenInfo.AccessToken != "access-token-abc" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } + if tokenInfo.TokenType != "Bearer" { + t.Fatalf("TokenType 不匹配: got=%q", tokenInfo.TokenType) + } + if tokenInfo.RefreshToken != "refresh-token-xyz" { + t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken) + } + if tokenInfo.OrgUUID != "org-uuid-111" { + t.Fatalf("OrgUUID 不匹配: got=%q", tokenInfo.OrgUUID) + } + if tokenInfo.AccountUUID != "acc-uuid-222" { + t.Fatalf("AccountUUID 不匹配: got=%q", tokenInfo.AccountUUID) + } + if tokenInfo.EmailAddress != "test@example.com" { + t.Fatalf("EmailAddress 不匹配: got=%q", tokenInfo.EmailAddress) + } + if tokenInfo.ExpiresIn != 3600 { + t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn) + } + if tokenInfo.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } + + // 验证 session 已被删除 + _, ok := svc.sessionStore.Get(result.SessionID) + if ok { + t.Fatal("session 应在交换成功后被删除") + } +} + +func TestOAuthService_ExchangeCode_SetupToken(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + if !isSetupToken { + t.Error("isSetupToken 应为 true(ScopeInference)") + } + return &oauth.TokenResponse{ + AccessToken: "setup-token", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: oauth.ScopeInference, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + // 使用 SetupToken URL(inference scope) + result, err := svc.GenerateSetupTokenURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err) + } + + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "setup-code", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + if tokenInfo.AccessToken != "setup-token" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } +} + +func TestOAuthService_ExchangeCode_ClientError(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + return nil, fmt.Errorf("upstream error: invalid code") + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + result, _ := svc.GenerateAuthURL(context.Background(), nil) + _, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "bad-code", + }) + if err == nil { + t.Fatal("ExchangeCode 应返回错误") + } + if err.Error() != "upstream error: invalid code" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_RefreshToken(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if refreshToken != "my-refresh-token" { + t.Errorf("refreshToken 不匹配: got=%q", refreshToken) + } + if proxyURL != "" { + t.Errorf("proxyURL 应为空: got=%q", proxyURL) + } + return &oauth.TokenResponse{ + AccessToken: "new-access-token", + TokenType: "Bearer", + ExpiresIn: 7200, + RefreshToken: "new-refresh-token", + Scope: oauth.ScopeOAuth, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + tokenInfo, err := svc.RefreshToken(context.Background(), "my-refresh-token", "") + if err != nil { + t.Fatalf("RefreshToken 返回错误: %v", err) + } + if tokenInfo.AccessToken != "new-access-token" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } + if tokenInfo.RefreshToken != "new-refresh-token" { + t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken) + } + if tokenInfo.ExpiresIn != 7200 { + t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn) + } + if tokenInfo.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } +} + +func TestOAuthService_RefreshToken_Error(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + return nil, fmt.Errorf("invalid_grant: token expired") + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + _, err := svc.RefreshToken(context.Background(), "expired-token", "") + if err == nil { + t.Fatal("RefreshToken 应返回错误") + } +} + +func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + // 无 refresh_token 的账号 + account := &Account{ + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "some-token", + }, + } + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("RefreshAccountToken 应返回错误(无 refresh_token)") + } + if err.Error() != "no refresh token available" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + account := &Account{ + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "some-token", + "refresh_token": "", + }, + } + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("RefreshAccountToken 应返回错误(refresh_token 为空)") + } +} + +func TestOAuthService_RefreshAccountToken_Success(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if refreshToken != "account-refresh-token" { + t.Errorf("refreshToken 不匹配: got=%q", refreshToken) + } + return &oauth.TokenResponse{ + AccessToken: "refreshed-access", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "new-refresh", + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + account := &Account{ + ID: 3, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-access", + "refresh_token": "account-refresh-token", + }, + } + + tokenInfo, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if tokenInfo.AccessToken != "refreshed-access" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } +} + +func TestOAuthService_RefreshAccountToken_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + Protocol: "socks5", + Host: "socks.example.com", + Port: 1080, + Username: "user", + Password: "pass", + }, nil + }, + } + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if proxyURL != "socks5://user:pass@socks.example.com:1080" { + t.Errorf("proxyURL 不匹配: got=%q", proxyURL) + } + return &oauth.TokenResponse{ + AccessToken: "refreshed", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewOAuthService(proxyRepo, client) + defer svc.Stop() + + proxyID := int64(10) + account := &Account{ + ID: 4, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + ProxyID: &proxyID, + Credentials: map[string]any{ + "refresh_token": "rt-with-proxy", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } +} + +func TestOAuthService_ExchangeCode_NilOrg(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + return &oauth.TokenResponse{ + AccessToken: "token-no-org", + TokenType: "Bearer", + ExpiresIn: 3600, + Organization: nil, + Account: nil, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + result, _ := svc.GenerateAuthURL(context.Background(), nil) + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "code", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + if tokenInfo.OrgUUID != "" { + t.Fatalf("OrgUUID 应为空: got=%q", tokenInfo.OrgUUID) + } + if tokenInfo.AccountUUID != "" { + t.Fatalf("AccountUUID 应为空: got=%q", tokenInfo.AccountUUID) + } +} + +func TestOAuthService_Stop_NoPanic(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + + // 调用 Stop 不应 panic + svc.Stop() + + // 多次调用也不应 panic + svc.Stop() +} diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go new file mode 100644 index 00000000..99013ce5 --- /dev/null +++ b/backend/internal/service/openai_account_scheduler.go @@ -0,0 +1,909 @@ +package service + +import ( + "container/heap" + "context" + "errors" + "hash/fnv" + "math" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +const ( + openAIAccountScheduleLayerPreviousResponse = "previous_response_id" + openAIAccountScheduleLayerSessionSticky = "session_hash" + openAIAccountScheduleLayerLoadBalance = "load_balance" +) + +type OpenAIAccountScheduleRequest struct { + GroupID *int64 + SessionHash string + StickyAccountID int64 + PreviousResponseID string + RequestedModel string + RequiredTransport OpenAIUpstreamTransport + ExcludedIDs map[int64]struct{} +} + +type OpenAIAccountScheduleDecision struct { + Layer string + StickyPreviousHit bool + StickySessionHit bool + CandidateCount int + TopK int + LatencyMs int64 + LoadSkew float64 + SelectedAccountID int64 + SelectedAccountType string +} + +type OpenAIAccountSchedulerMetricsSnapshot struct { + SelectTotal int64 + StickyPreviousHitTotal int64 + StickySessionHitTotal int64 + LoadBalanceSelectTotal int64 + AccountSwitchTotal int64 + SchedulerLatencyMsTotal int64 + SchedulerLatencyMsAvg float64 + StickyHitRatio float64 + AccountSwitchRate float64 + LoadSkewAvg float64 + RuntimeStatsAccountCount int +} + +type OpenAIAccountScheduler interface { + Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) + ReportResult(accountID int64, success bool, firstTokenMs *int) + ReportSwitch() + SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot +} + +type openAIAccountSchedulerMetrics struct { + selectTotal atomic.Int64 + stickyPreviousHitTotal atomic.Int64 + stickySessionHitTotal atomic.Int64 + loadBalanceSelectTotal atomic.Int64 + accountSwitchTotal atomic.Int64 + latencyMsTotal atomic.Int64 + loadSkewMilliTotal atomic.Int64 +} + +func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) { + if m == nil { + return + } + m.selectTotal.Add(1) + m.latencyMsTotal.Add(decision.LatencyMs) + m.loadSkewMilliTotal.Add(int64(math.Round(decision.LoadSkew * 1000))) + if decision.StickyPreviousHit { + m.stickyPreviousHitTotal.Add(1) + } + if decision.StickySessionHit { + m.stickySessionHitTotal.Add(1) + } + if decision.Layer == openAIAccountScheduleLayerLoadBalance { + m.loadBalanceSelectTotal.Add(1) + } +} + +func (m *openAIAccountSchedulerMetrics) recordSwitch() { + if m == nil { + return + } + m.accountSwitchTotal.Add(1) +} + +type openAIAccountRuntimeStats struct { + accounts sync.Map + accountCount atomic.Int64 +} + +type openAIAccountRuntimeStat struct { + errorRateEWMABits atomic.Uint64 + ttftEWMABits atomic.Uint64 +} + +func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats { + return &openAIAccountRuntimeStats{} +} + +func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat { + if value, ok := s.accounts.Load(accountID); ok { + stat, _ := value.(*openAIAccountRuntimeStat) + if stat != nil { + return stat + } + } + + stat := &openAIAccountRuntimeStat{} + stat.ttftEWMABits.Store(math.Float64bits(math.NaN())) + actual, loaded := s.accounts.LoadOrStore(accountID, stat) + if !loaded { + s.accountCount.Add(1) + return stat + } + existing, _ := actual.(*openAIAccountRuntimeStat) + if existing != nil { + return existing + } + return stat +} + +func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) { + for { + oldBits := target.Load() + oldValue := math.Float64frombits(oldBits) + newValue := alpha*sample + (1-alpha)*oldValue + if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + return + } + } +} + +func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) { + if s == nil || accountID <= 0 { + return + } + const alpha = 0.2 + stat := s.loadOrCreate(accountID) + + errorSample := 1.0 + if success { + errorSample = 0.0 + } + updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha) + + if firstTokenMs != nil && *firstTokenMs > 0 { + ttft := float64(*firstTokenMs) + ttftBits := math.Float64bits(ttft) + for { + oldBits := stat.ttftEWMABits.Load() + oldValue := math.Float64frombits(oldBits) + if math.IsNaN(oldValue) { + if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) { + break + } + continue + } + newValue := alpha*ttft + (1-alpha)*oldValue + if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + break + } + } + } +} + +func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) { + if s == nil || accountID <= 0 { + return 0, 0, false + } + value, ok := s.accounts.Load(accountID) + if !ok { + return 0, 0, false + } + stat, _ := value.(*openAIAccountRuntimeStat) + if stat == nil { + return 0, 0, false + } + errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load())) + ttftValue := math.Float64frombits(stat.ttftEWMABits.Load()) + if math.IsNaN(ttftValue) { + return errorRate, 0, false + } + return errorRate, ttftValue, true +} + +func (s *openAIAccountRuntimeStats) size() int { + if s == nil { + return 0 + } + return int(s.accountCount.Load()) +} + +type defaultOpenAIAccountScheduler struct { + service *OpenAIGatewayService + metrics openAIAccountSchedulerMetrics + stats *openAIAccountRuntimeStats +} + +func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler { + if stats == nil { + stats = newOpenAIAccountRuntimeStats() + } + return &defaultOpenAIAccountScheduler{ + service: service, + stats: stats, + } +} + +func (s *defaultOpenAIAccountScheduler) Select( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + decision := OpenAIAccountScheduleDecision{} + start := time.Now() + defer func() { + decision.LatencyMs = time.Since(start).Milliseconds() + s.metrics.recordSelect(decision) + }() + + previousResponseID := strings.TrimSpace(req.PreviousResponseID) + if previousResponseID != "" { + selection, err := s.service.SelectAccountByPreviousResponseID( + ctx, + req.GroupID, + previousResponseID, + req.RequestedModel, + req.ExcludedIDs, + ) + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) { + selection = nil + } + } + if selection != nil && selection.Account != nil { + decision.Layer = openAIAccountScheduleLayerPreviousResponse + decision.StickyPreviousHit = true + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + if req.SessionHash != "" { + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID) + } + return selection, decision, nil + } + } + + selection, err := s.selectBySessionHash(ctx, req) + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + decision.Layer = openAIAccountScheduleLayerSessionSticky + decision.StickySessionHit = true + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + return selection, decision, nil + } + + selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req) + decision.Layer = openAIAccountScheduleLayerLoadBalance + decision.CandidateCount = candidateCount + decision.TopK = topK + decision.LoadSkew = loadSkew + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + } + return selection, decision, nil +} + +func (s *defaultOpenAIAccountScheduler) selectBySessionHash( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, error) { + sessionHash := strings.TrimSpace(req.SessionHash) + if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil { + return nil, nil + } + + accountID := req.StickyAccountID + if accountID <= 0 { + var err error + accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash) + if err != nil || accountID <= 0 { + return nil, nil + } + } + if accountID <= 0 { + return nil, nil + } + if req.ExcludedIDs != nil { + if _, excluded := req.ExcludedIDs[accountID]; excluded { + return nil, nil + } + } + + account, err := s.service.getSchedulableAccount(ctx, accountID) + if err != nil || account == nil { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + return nil, nil + } + if !s.isAccountTransportCompatible(account, req.RequiredTransport) { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if acquireErr == nil && result.Acquired { + _ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL()) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + cfg := s.service.schedulingConfig() + if s.service.concurrencyService != nil { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + return nil, nil +} + +type openAIAccountCandidateScore struct { + account *Account + loadInfo *AccountLoadInfo + score float64 + errorRate float64 + ttft float64 + hasTTFT bool +} + +type openAIAccountCandidateHeap []openAIAccountCandidateScore + +func (h openAIAccountCandidateHeap) Len() int { + return len(h) +} + +func (h openAIAccountCandidateHeap) Less(i, j int) bool { + // 最小堆根节点保存“最差”候选,便于 O(log k) 维护 topK。 + return isOpenAIAccountCandidateBetter(h[j], h[i]) +} + +func (h openAIAccountCandidateHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *openAIAccountCandidateHeap) Push(x any) { + candidate, ok := x.(openAIAccountCandidateScore) + if !ok { + panic("openAIAccountCandidateHeap: invalid element type") + } + *h = append(*h, candidate) +} + +func (h *openAIAccountCandidateHeap) Pop() any { + old := *h + n := len(old) + last := old[n-1] + *h = old[:n-1] + return last +} + +func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool { + if left.score != right.score { + return left.score > right.score + } + if left.account.Priority != right.account.Priority { + return left.account.Priority < right.account.Priority + } + if left.loadInfo.LoadRate != right.loadInfo.LoadRate { + return left.loadInfo.LoadRate < right.loadInfo.LoadRate + } + if left.loadInfo.WaitingCount != right.loadInfo.WaitingCount { + return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount + } + return left.account.ID < right.account.ID +} + +func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore { + if len(candidates) == 0 { + return nil + } + if topK <= 0 { + topK = 1 + } + if topK >= len(candidates) { + ranked := append([]openAIAccountCandidateScore(nil), candidates...) + sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + return ranked + } + + best := make(openAIAccountCandidateHeap, 0, topK) + for _, candidate := range candidates { + if len(best) < topK { + heap.Push(&best, candidate) + continue + } + if isOpenAIAccountCandidateBetter(candidate, best[0]) { + best[0] = candidate + heap.Fix(&best, 0) + } + } + + ranked := make([]openAIAccountCandidateScore, len(best)) + copy(ranked, best) + sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + return ranked +} + +type openAISelectionRNG struct { + state uint64 +} + +func newOpenAISelectionRNG(seed uint64) openAISelectionRNG { + if seed == 0 { + seed = 0x9e3779b97f4a7c15 + } + return openAISelectionRNG{state: seed} +} + +func (r *openAISelectionRNG) nextUint64() uint64 { + // xorshift64* + x := r.state + x ^= x >> 12 + x ^= x << 25 + x ^= x >> 27 + r.state = x + return x * 2685821657736338717 +} + +func (r *openAISelectionRNG) nextFloat64() float64 { + // [0,1) + return float64(r.nextUint64()>>11) / (1 << 53) +} + +func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 { + hasher := fnv.New64a() + writeValue := func(value string) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return + } + _, _ = hasher.Write([]byte(trimmed)) + _, _ = hasher.Write([]byte{0}) + } + + writeValue(req.SessionHash) + writeValue(req.PreviousResponseID) + writeValue(req.RequestedModel) + if req.GroupID != nil { + _, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10))) + } + + seed := hasher.Sum64() + // 对“无会话锚点”的纯负载均衡请求引入时间熵,避免固定命中同一账号。 + if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" { + seed ^= uint64(time.Now().UnixNano()) + } + if seed == 0 { + seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15 + } + return seed +} + +func buildOpenAIWeightedSelectionOrder( + candidates []openAIAccountCandidateScore, + req OpenAIAccountScheduleRequest, +) []openAIAccountCandidateScore { + if len(candidates) <= 1 { + return append([]openAIAccountCandidateScore(nil), candidates...) + } + + pool := append([]openAIAccountCandidateScore(nil), candidates...) + weights := make([]float64, len(pool)) + minScore := pool[0].score + for i := 1; i < len(pool); i++ { + if pool[i].score < minScore { + minScore = pool[i].score + } + } + for i := range pool { + // 将 top-K 分值平移到正区间,避免“单一最高分账号”长期垄断。 + weight := (pool[i].score - minScore) + 1.0 + if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 { + weight = 1.0 + } + weights[i] = weight + } + + order := make([]openAIAccountCandidateScore, 0, len(pool)) + rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req)) + for len(pool) > 0 { + total := 0.0 + for _, w := range weights { + total += w + } + + selectedIdx := 0 + if total > 0 { + r := rng.nextFloat64() * total + acc := 0.0 + for i, w := range weights { + acc += w + if r <= acc { + selectedIdx = i + break + } + } + } else { + selectedIdx = int(rng.nextUint64() % uint64(len(pool))) + } + + order = append(order, pool[selectedIdx]) + pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...) + weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...) + } + return order +} + +func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, int, int, float64, error) { + accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID) + if err != nil { + return nil, 0, 0, 0, err + } + if len(accounts) == 0 { + return nil, 0, 0, 0, errors.New("no available OpenAI accounts") + } + + filtered := make([]*Account, 0, len(accounts)) + loadReq := make([]AccountWithConcurrency, 0, len(accounts)) + for i := range accounts { + account := &accounts[i] + if req.ExcludedIDs != nil { + if _, excluded := req.ExcludedIDs[account.ID]; excluded { + continue + } + } + if !account.IsSchedulable() || !account.IsOpenAI() { + continue + } + if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + continue + } + if !s.isAccountTransportCompatible(account, req.RequiredTransport) { + continue + } + filtered = append(filtered, account) + loadReq = append(loadReq, AccountWithConcurrency{ + ID: account.ID, + MaxConcurrency: account.Concurrency, + }) + } + if len(filtered) == 0 { + return nil, 0, 0, 0, errors.New("no available OpenAI accounts") + } + + loadMap := map[int64]*AccountLoadInfo{} + if s.service.concurrencyService != nil { + if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil { + loadMap = batchLoad + } + } + + minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority + maxWaiting := 1 + loadRateSum := 0.0 + loadRateSumSquares := 0.0 + minTTFT, maxTTFT := 0.0, 0.0 + hasTTFTSample := false + candidates := make([]openAIAccountCandidateScore, 0, len(filtered)) + for _, account := range filtered { + loadInfo := loadMap[account.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: account.ID} + } + if account.Priority < minPriority { + minPriority = account.Priority + } + if account.Priority > maxPriority { + maxPriority = account.Priority + } + if loadInfo.WaitingCount > maxWaiting { + maxWaiting = loadInfo.WaitingCount + } + errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID) + if hasTTFT && ttft > 0 { + if !hasTTFTSample { + minTTFT, maxTTFT = ttft, ttft + hasTTFTSample = true + } else { + if ttft < minTTFT { + minTTFT = ttft + } + if ttft > maxTTFT { + maxTTFT = ttft + } + } + } + loadRate := float64(loadInfo.LoadRate) + loadRateSum += loadRate + loadRateSumSquares += loadRate * loadRate + candidates = append(candidates, openAIAccountCandidateScore{ + account: account, + loadInfo: loadInfo, + errorRate: errorRate, + ttft: ttft, + hasTTFT: hasTTFT, + }) + } + loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates)) + + weights := s.service.openAIWSSchedulerWeights() + for i := range candidates { + item := &candidates[i] + priorityFactor := 1.0 + if maxPriority > minPriority { + priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority) + } + loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0) + queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting)) + errorFactor := 1 - clamp01(item.errorRate) + ttftFactor := 0.5 + if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { + ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) + } + + item.score = weights.Priority*priorityFactor + + weights.Load*loadFactor + + weights.Queue*queueFactor + + weights.ErrorRate*errorFactor + + weights.TTFT*ttftFactor + } + + topK := s.service.openAIWSLBTopK() + if topK > len(candidates) { + topK = len(candidates) + } + if topK <= 0 { + topK = 1 + } + rankedCandidates := selectTopKOpenAICandidates(candidates, topK) + selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req) + + for i := 0; i < len(selectionOrder); i++ { + candidate := selectionOrder[i] + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency) + if acquireErr != nil { + return nil, len(candidates), topK, loadSkew, acquireErr + } + if result != nil && result.Acquired { + if req.SessionHash != "" { + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID) + } + return &AccountSelectionResult{ + Account: candidate.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, len(candidates), topK, loadSkew, nil + } + } + + cfg := s.service.schedulingConfig() + candidate := selectionOrder[0] + return &AccountSelectionResult{ + Account: candidate.account, + WaitPlan: &AccountWaitPlan{ + AccountID: candidate.account.ID, + MaxConcurrency: candidate.account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, len(candidates), topK, loadSkew, nil +} + +func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { + // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。 + if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { + return true + } + if s == nil || s.service == nil || account == nil { + return false + } + return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport +} + +func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { + if s == nil || s.stats == nil { + return + } + s.stats.report(accountID, success, firstTokenMs) +} + +func (s *defaultOpenAIAccountScheduler) ReportSwitch() { + if s == nil { + return + } + s.metrics.recordSwitch() +} + +func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot { + if s == nil { + return OpenAIAccountSchedulerMetricsSnapshot{} + } + + selectTotal := s.metrics.selectTotal.Load() + prevHit := s.metrics.stickyPreviousHitTotal.Load() + sessionHit := s.metrics.stickySessionHitTotal.Load() + switchTotal := s.metrics.accountSwitchTotal.Load() + latencyTotal := s.metrics.latencyMsTotal.Load() + loadSkewTotal := s.metrics.loadSkewMilliTotal.Load() + + snapshot := OpenAIAccountSchedulerMetricsSnapshot{ + SelectTotal: selectTotal, + StickyPreviousHitTotal: prevHit, + StickySessionHitTotal: sessionHit, + LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(), + AccountSwitchTotal: switchTotal, + SchedulerLatencyMsTotal: latencyTotal, + RuntimeStatsAccountCount: s.stats.size(), + } + if selectTotal > 0 { + snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal) + snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal) + snapshot.AccountSwitchRate = float64(switchTotal) / float64(selectTotal) + snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal) + } + return snapshot +} + +func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler { + if s == nil { + return nil + } + s.openaiSchedulerOnce.Do(func() { + if s.openaiAccountStats == nil { + s.openaiAccountStats = newOpenAIAccountRuntimeStats() + } + if s.openaiScheduler == nil { + s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats) + } + }) + return s.openaiScheduler +} + +func (s *OpenAIGatewayService) SelectAccountWithScheduler( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport OpenAIUpstreamTransport, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + decision := OpenAIAccountScheduleDecision{} + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) + decision.Layer = openAIAccountScheduleLayerLoadBalance + return selection, decision, err + } + + var stickyAccountID int64 + if sessionHash != "" && s.cache != nil { + if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 { + stickyAccountID = accountID + } + } + + return scheduler.Select(ctx, OpenAIAccountScheduleRequest{ + GroupID: groupID, + SessionHash: sessionHash, + StickyAccountID: stickyAccountID, + PreviousResponseID: previousResponseID, + RequestedModel: requestedModel, + RequiredTransport: requiredTransport, + ExcludedIDs: excludedIDs, + }) +} + +func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return + } + scheduler.ReportResult(accountID, success, firstTokenMs) +} + +func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return + } + scheduler.ReportSwitch() +} + +func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return OpenAIAccountSchedulerMetricsSnapshot{} + } + return scheduler.SnapshotMetrics() +} + +func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second + } + return openaiStickySessionTTL +} + +func (s *OpenAIGatewayService) openAIWSLBTopK() int { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 { + return s.cfg.Gateway.OpenAIWS.LBTopK + } + return 7 +} + +func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView { + if s != nil && s.cfg != nil { + return GatewayOpenAIWSSchedulerScoreWeightsView{ + Priority: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority, + Load: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load, + Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue, + ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate, + TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT, + } + } + return GatewayOpenAIWSSchedulerScoreWeightsView{ + Priority: 1.0, + Load: 1.0, + Queue: 0.7, + ErrorRate: 0.8, + TTFT: 0.5, + } +} + +type GatewayOpenAIWSSchedulerScoreWeightsView struct { + Priority float64 + Load float64 + Queue float64 + ErrorRate float64 + TTFT float64 +} + +func clamp01(value float64) float64 { + switch { + case value < 0: + return 0 + case value > 1: + return 1 + default: + return value + } +} + +func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 { + if count <= 1 { + return 0 + } + mean := sum / float64(count) + variance := sumSquares/float64(count) - mean*mean + if variance < 0 { + variance = 0 + } + return math.Sqrt(variance) +} diff --git a/backend/internal/service/openai_account_scheduler_benchmark_test.go b/backend/internal/service/openai_account_scheduler_benchmark_test.go new file mode 100644 index 00000000..897be5b0 --- /dev/null +++ b/backend/internal/service/openai_account_scheduler_benchmark_test.go @@ -0,0 +1,83 @@ +package service + +import ( + "sort" + "testing" +) + +func buildOpenAISchedulerBenchmarkCandidates(size int) []openAIAccountCandidateScore { + if size <= 0 { + return nil + } + candidates := make([]openAIAccountCandidateScore, 0, size) + for i := 0; i < size; i++ { + accountID := int64(10_000 + i) + candidates = append(candidates, openAIAccountCandidateScore{ + account: &Account{ + ID: accountID, + Priority: i % 7, + }, + loadInfo: &AccountLoadInfo{ + AccountID: accountID, + LoadRate: (i * 17) % 100, + WaitingCount: (i * 11) % 13, + }, + score: float64((i*29)%1000) / 100, + errorRate: float64((i * 5) % 100 / 100), + ttft: float64(30 + (i*3)%500), + hasTTFT: i%3 != 0, + }) + } + return candidates +} + +func selectTopKOpenAICandidatesBySortBenchmark(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore { + if len(candidates) == 0 { + return nil + } + if topK <= 0 { + topK = 1 + } + ranked := append([]openAIAccountCandidateScore(nil), candidates...) + sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + if topK > len(ranked) { + topK = len(ranked) + } + return ranked[:topK] +} + +func BenchmarkOpenAIAccountSchedulerSelectTopK(b *testing.B) { + cases := []struct { + name string + size int + topK int + }{ + {name: "n_16_k_3", size: 16, topK: 3}, + {name: "n_64_k_3", size: 64, topK: 3}, + {name: "n_256_k_5", size: 256, topK: 5}, + } + + for _, tc := range cases { + candidates := buildOpenAISchedulerBenchmarkCandidates(tc.size) + b.Run(tc.name+"/heap_topk", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + result := selectTopKOpenAICandidates(candidates, tc.topK) + if len(result) == 0 { + b.Fatal("unexpected empty result") + } + } + }) + b.Run(tc.name+"/full_sort", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + result := selectTopKOpenAICandidatesBySortBenchmark(candidates, tc.topK) + if len(result) == 0 { + b.Fatal("unexpected empty result") + } + } + }) + } +} diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go new file mode 100644 index 00000000..7f6f1b66 --- /dev/null +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -0,0 +1,841 @@ +package service + +import ( + "context" + "fmt" + "math" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(9) + account := Account{ + ID: 1001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_001", account.ID, time.Hour)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_prev_001", + "session_hash_001", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer) + require.True(t, decision.StickyPreviousHit) + require.Equal(t, account.ID, cache.sessionBindings["openai:session_hash_001"]) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(10) + account := Account{ + ID: 2001, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_abc": account.ID, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_abc", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(10100) + accounts := []Account{ + { + ID: 21001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 21002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 9, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_sticky_busy": 21001, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2 + cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + concurrencyCache := stubConcurrencyCache{ + acquireResults: map[int64]bool{ + 21001: false, // sticky 账号已满 + 21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换) + }, + waitCounts: map[int64]int{ + 21001: 999, + }, + loadMap: map[int64]*AccountLoadInfo{ + 21001: {AccountID: 21001, LoadRate: 90, WaitingCount: 9}, + 21002: {AccountID: 21002, LoadRate: 1, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_sticky_busy", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(21001), selection.Account.ID, "busy sticky account should remain selected") + require.False(t, selection.Acquired) + require.NotNil(t, selection.WaitPlan) + require.Equal(t, int64(21001), selection.WaitPlan.AccountID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP(t *testing.T) { + ctx := context.Background() + groupID := int64(1010) + account := Account{ + ID: 2101, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_ws_force_http": true, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_force_http": account.ID, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_force_http", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStickyHTTPAccount(t *testing.T) { + ctx := context.Background() + groupID := int64(1011) + accounts := []Account{ + { + ID: 2201, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 2202, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_ws_only": 2201, + }, + } + cfg := newOpenAIWSV2TestConfig() + + // 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。 + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0}, + 2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_ws_only", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(2202), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.False(t, decision.StickySessionHit) + require.Equal(t, 1, decision.CandidateCount) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailableAccount(t *testing.T) { + ctx := context.Background() + groupID := int64(1012) + accounts := []Account{ + { + ID: 2301, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: newOpenAIWSV2TestConfig(), + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.Error(t, err) + require.Nil(t, selection) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 0, decision.CandidateCount) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback(t *testing.T) { + ctx := context.Background() + groupID := int64(11) + accounts := []Account{ + { + ID: 3001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 3002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 3003, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8}, + 3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1}, + 3003: {AccountID: 3003, LoadRate: 10, WaitingCount: 0}, + }, + acquireResults: map[int64]bool{ + 3003: false, // top1 失败,必须回退到 top-K 的下一候选 + 3002: true, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(3002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 3, decision.CandidateCount) + require.Equal(t, 2, decision.TopK) + require.Greater(t, decision.LoadSkew, 0.0) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { + ctx := context.Background() + groupID := int64(12) + account := Account{ + ID: 4001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_metrics": account.ID, + }, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120)) + svc.RecordOpenAIAccountSwitch() + + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.GreaterOrEqual(t, snapshot.SelectTotal, int64(1)) + require.GreaterOrEqual(t, snapshot.StickySessionHitTotal, int64(1)) + require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.GreaterOrEqual(t, snapshot.SchedulerLatencyMsAvg, float64(0)) + require.GreaterOrEqual(t, snapshot.StickyHitRatio, 0.0) + require.GreaterOrEqual(t, snapshot.RuntimeStatsAccountCount, 1) +} + +func intPtrForTest(v int) *int { + return &v +} + +func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + stats.report(1001, true, nil) + firstTTFT := 100 + stats.report(1001, false, &firstTTFT) + secondTTFT := 200 + stats.report(1001, false, &secondTTFT) + + errorRate, ttft, hasTTFT := stats.snapshot(1001) + require.True(t, hasTTFT) + require.InDelta(t, 0.36, errorRate, 1e-9) + require.InDelta(t, 120.0, ttft, 1e-9) + require.Equal(t, 1, stats.size()) +} + +func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + + const ( + accountCount = 4 + workers = 16 + iterations = 800 + ) + var wg sync.WaitGroup + wg.Add(workers) + for worker := 0; worker < workers; worker++ { + worker := worker + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + accountID := int64(i%accountCount + 1) + success := (i+worker)%3 != 0 + ttft := 80 + (i+worker)%40 + stats.report(accountID, success, &ttft) + } + }() + } + wg.Wait() + + require.Equal(t, accountCount, stats.size()) + for accountID := int64(1); accountID <= accountCount; accountID++ { + errorRate, ttft, hasTTFT := stats.snapshot(accountID) + require.GreaterOrEqual(t, errorRate, 0.0) + require.LessOrEqual(t, errorRate, 1.0) + require.True(t, hasTTFT) + require.Greater(t, ttft, 0.0) + } +} + +func TestSelectTopKOpenAICandidates(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 11, Priority: 2}, + loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 1}, + score: 10.0, + }, + { + account: &Account{ID: 12, Priority: 1}, + loadInfo: &AccountLoadInfo{LoadRate: 20, WaitingCount: 1}, + score: 9.5, + }, + { + account: &Account{ID: 13, Priority: 1}, + loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 0}, + score: 10.0, + }, + { + account: &Account{ID: 14, Priority: 0}, + loadInfo: &AccountLoadInfo{LoadRate: 40, WaitingCount: 0}, + score: 8.0, + }, + } + + top2 := selectTopKOpenAICandidates(candidates, 2) + require.Len(t, top2, 2) + require.Equal(t, int64(13), top2[0].account.ID) + require.Equal(t, int64(11), top2[1].account.ID) + + topAll := selectTopKOpenAICandidates(candidates, 8) + require.Len(t, topAll, len(candidates)) + require.Equal(t, int64(13), topAll[0].account.ID) + require.Equal(t, int64(11), topAll[1].account.ID) + require.Equal(t, int64(12), topAll[2].account.ID) + require.Equal(t, int64(14), topAll[3].account.ID) +} + +func TestBuildOpenAIWeightedSelectionOrder_DeterministicBySessionSeed(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 101}, + loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 0}, + score: 4.2, + }, + { + account: &Account{ID: 102}, + loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 1}, + score: 3.5, + }, + { + account: &Account{ID: 103}, + loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2}, + score: 2.1, + }, + } + req := OpenAIAccountScheduleRequest{ + GroupID: int64PtrForTest(99), + SessionHash: "session_seed_fixed", + RequestedModel: "gpt-5.1", + } + + first := buildOpenAIWeightedSelectionOrder(candidates, req) + second := buildOpenAIWeightedSelectionOrder(candidates, req) + require.Len(t, first, len(candidates)) + require.Len(t, second, len(candidates)) + for i := range first { + require.Equal(t, first[i].account.ID, second[i].account.ID) + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesAcrossSessions(t *testing.T) { + ctx := context.Background() + groupID := int64(15) + accounts := []Account{ + { + ID: 5101, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + { + ID: 5102, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + { + ID: 5103, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 3 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1}, + 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1}, + 5103: {AccountID: 5103, LoadRate: 20, WaitingCount: 1}, + }, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selected := make(map[int64]int, len(accounts)) + for i := 0; i < 60; i++ { + sessionHash := fmt.Sprintf("session_hash_lb_%d", i) + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + selected[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // 多 session 应该能打散到多个账号,避免“恒定单账号命中”。 + require.GreaterOrEqual(t, len(selected), 2) +} + +func TestDeriveOpenAISelectionSeed_NoAffinityAddsEntropy(t *testing.T) { + req := OpenAIAccountScheduleRequest{ + RequestedModel: "gpt-5.1", + } + seed1 := deriveOpenAISelectionSeed(req) + time.Sleep(1 * time.Millisecond) + seed2 := deriveOpenAISelectionSeed(req) + require.NotZero(t, seed1) + require.NotZero(t, seed2) + require.NotEqual(t, seed1, seed2) +} + +func TestBuildOpenAIWeightedSelectionOrder_HandlesInvalidScores(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 901}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: math.NaN(), + }, + { + account: &Account{ID: 902}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: math.Inf(1), + }, + { + account: &Account{ID: 903}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: -1, + }, + } + req := OpenAIAccountScheduleRequest{ + SessionHash: "seed_invalid_scores", + } + + order := buildOpenAIWeightedSelectionOrder(candidates, req) + require.Len(t, order, len(candidates)) + seen := map[int64]struct{}{} + for _, item := range order { + seen[item.account.ID] = struct{}{} + } + require.Len(t, seen, len(candidates)) +} + +func TestOpenAISelectionRNG_SeedZeroStillWorks(t *testing.T) { + rng := newOpenAISelectionRNG(0) + v1 := rng.nextUint64() + v2 := rng.nextUint64() + require.NotEqual(t, v1, v2) + require.GreaterOrEqual(t, rng.nextFloat64(), 0.0) + require.Less(t, rng.nextFloat64(), 1.0) +} + +func TestOpenAIAccountCandidateHeap_PushPopAndInvalidType(t *testing.T) { + h := openAIAccountCandidateHeap{} + h.Push(openAIAccountCandidateScore{ + account: &Account{ID: 7001}, + loadInfo: &AccountLoadInfo{LoadRate: 0, WaitingCount: 0}, + score: 1.0, + }) + require.Equal(t, 1, h.Len()) + popped, ok := h.Pop().(openAIAccountCandidateScore) + require.True(t, ok) + require.Equal(t, int64(7001), popped.account.ID) + require.Equal(t, 0, h.Len()) + + require.Panics(t, func() { + h.Push("bad_element_type") + }) +} + +func TestClamp01_AllBranches(t *testing.T) { + require.Equal(t, 0.0, clamp01(-0.2)) + require.Equal(t, 1.0, clamp01(1.3)) + require.Equal(t, 0.5, clamp01(0.5)) +} + +func TestCalcLoadSkewByMoments_Branches(t *testing.T) { + require.Equal(t, 0.0, calcLoadSkewByMoments(1, 1, 1)) + // variance < 0 分支:sumSquares/count - mean^2 为负值时应钳制为 0。 + require.Equal(t, 0.0, calcLoadSkewByMoments(1, 0, 2)) + require.GreaterOrEqual(t, calcLoadSkewByMoments(6, 20, 3), 0.0) +} + +func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { + schedulerAny := newDefaultOpenAIAccountScheduler(&OpenAIGatewayService{}, nil) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + ttft := 100 + scheduler.ReportResult(1001, true, &ttft) + scheduler.ReportSwitch() + scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{ + Layer: openAIAccountScheduleLayerLoadBalance, + LatencyMs: 8, + LoadSkew: 0.5, + StickyPreviousHit: true, + }) + scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{ + Layer: openAIAccountScheduleLayerSessionSticky, + LatencyMs: 6, + LoadSkew: 0.2, + StickySessionHit: true, + }) + + snapshot := scheduler.SnapshotMetrics() + require.Equal(t, int64(2), snapshot.SelectTotal) + require.Equal(t, int64(1), snapshot.StickyPreviousHitTotal) + require.Equal(t, int64(1), snapshot.StickySessionHitTotal) + require.Equal(t, int64(1), snapshot.LoadBalanceSelectTotal) + require.Equal(t, int64(1), snapshot.AccountSwitchTotal) + require.Greater(t, snapshot.SchedulerLatencyMsAvg, 0.0) + require.Greater(t, snapshot.StickyHitRatio, 0.0) + require.Greater(t, snapshot.LoadSkewAvg, 0.0) +} + +func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) { + svc := &OpenAIGatewayService{} + ttft := 120 + svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) + svc.RecordOpenAIAccountSwitch() + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.Equal(t, 7, svc.openAIWSLBTopK()) + require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL()) + + defaultWeights := svc.openAIWSSchedulerWeights() + require.Equal(t, 1.0, defaultWeights.Priority) + require.Equal(t, 1.0, defaultWeights.Load) + require.Equal(t, 0.7, defaultWeights.Queue) + require.Equal(t, 0.8, defaultWeights.ErrorRate) + require.Equal(t, 0.5, defaultWeights.TTFT) + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 9 + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 180 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.3 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.4 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.5 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.6 + svcWithCfg := &OpenAIGatewayService{cfg: cfg} + + require.Equal(t, 9, svcWithCfg.openAIWSLBTopK()) + require.Equal(t, 180*time.Second, svcWithCfg.openAIWSSessionStickyTTL()) + customWeights := svcWithCfg.openAIWSSchedulerWeights() + require.Equal(t, 0.2, customWeights.Priority) + require.Equal(t, 0.3, customWeights.Load) + require.Equal(t, 0.4, customWeights.Queue) + require.Equal(t, 0.5, customWeights.ErrorRate) + require.Equal(t, 0.6, customWeights.TTFT) +} + +func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *testing.T) { + scheduler := &defaultOpenAIAccountScheduler{} + require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportAny)) + require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE)) + require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2)) + + cfg := newOpenAIWSV2TestConfig() + scheduler.service = &OpenAIGatewayService{cfg: cfg} + account := &Account{ + ID: 8801, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2)) +} + +func int64PtrForTest(v int64) *int64 { + return &v +} diff --git a/backend/internal/service/openai_client_restriction_detector.go b/backend/internal/service/openai_client_restriction_detector.go new file mode 100644 index 00000000..d1784e11 --- /dev/null +++ b/backend/internal/service/openai_client_restriction_detector.go @@ -0,0 +1,86 @@ +package service + +import ( + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/gin-gonic/gin" +) + +const ( + // CodexClientRestrictionReasonDisabled 表示账号未开启 codex_cli_only。 + CodexClientRestrictionReasonDisabled = "codex_cli_only_disabled" + // CodexClientRestrictionReasonMatchedUA 表示请求命中官方客户端 UA 白名单。 + CodexClientRestrictionReasonMatchedUA = "official_client_user_agent_matched" + // CodexClientRestrictionReasonMatchedOriginator 表示请求命中官方客户端 originator 白名单。 + CodexClientRestrictionReasonMatchedOriginator = "official_client_originator_matched" + // CodexClientRestrictionReasonNotMatchedUA 表示请求未命中官方客户端 UA 白名单。 + CodexClientRestrictionReasonNotMatchedUA = "official_client_user_agent_not_matched" + // CodexClientRestrictionReasonForceCodexCLI 表示通过 ForceCodexCLI 配置兜底放行。 + CodexClientRestrictionReasonForceCodexCLI = "force_codex_cli_enabled" +) + +// CodexClientRestrictionDetectionResult 是 codex_cli_only 统一检测入口结果。 +type CodexClientRestrictionDetectionResult struct { + Enabled bool + Matched bool + Reason string +} + +// CodexClientRestrictionDetector 定义 codex_cli_only 统一检测入口。 +type CodexClientRestrictionDetector interface { + Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult +} + +// OpenAICodexClientRestrictionDetector 为 OpenAI OAuth codex_cli_only 的默认实现。 +type OpenAICodexClientRestrictionDetector struct { + cfg *config.Config +} + +func NewOpenAICodexClientRestrictionDetector(cfg *config.Config) *OpenAICodexClientRestrictionDetector { + return &OpenAICodexClientRestrictionDetector{cfg: cfg} +} + +func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { + if account == nil || !account.IsCodexCLIOnlyEnabled() { + return CodexClientRestrictionDetectionResult{ + Enabled: false, + Matched: false, + Reason: CodexClientRestrictionReasonDisabled, + } + } + + if d != nil && d.cfg != nil && d.cfg.Gateway.ForceCodexCLI { + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonForceCodexCLI, + } + } + + userAgent := "" + originator := "" + if c != nil { + userAgent = c.GetHeader("User-Agent") + originator = c.GetHeader("originator") + } + if openai.IsCodexOfficialClientRequest(userAgent) { + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonMatchedUA, + } + } + if openai.IsCodexOfficialClientOriginator(originator) { + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonMatchedOriginator, + } + } + + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: false, + Reason: CodexClientRestrictionReasonNotMatchedUA, + } +} diff --git a/backend/internal/service/openai_client_restriction_detector_test.go b/backend/internal/service/openai_client_restriction_detector_test.go new file mode 100644 index 00000000..984b4ff6 --- /dev/null +++ b/backend/internal/service/openai_client_restriction_detector_test.go @@ -0,0 +1,124 @@ +package service + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func newCodexDetectorTestContext(ua string, originator string) *gin.Context { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + if ua != "" { + c.Request.Header.Set("User-Agent", ua) + } + if originator != "" { + c.Request.Header.Set("originator", originator) + } + return c +} + +func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("未开启开关时绕过", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}} + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", ""), account) + require.False(t, result.Enabled) + require.False(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonDisabled, result.Reason) + }) + + t.Run("开启后 codex_cli_rs 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("codex_cli_rs/0.99.0", ""), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason) + }) + + t.Run("开启后 codex_vscode 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("codex_vscode/1.0.0", ""), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason) + }) + + t.Run("开启后 codex_app 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("codex_app/2.1.0", ""), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason) + }) + + t.Run("开启后 originator 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "codex_chatgpt_desktop"), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedOriginator, result.Reason) + }) + + t.Run("开启后非官方客户端拒绝", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account) + require.True(t, result.Enabled) + require.False(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason) + }) + + t.Run("开启 ForceCodexCLI 时允许通过", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(&config.Config{ + Gateway: config.GatewayConfig{ForceCodexCLI: true}, + }) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason) + }) +} diff --git a/backend/internal/service/openai_client_transport.go b/backend/internal/service/openai_client_transport.go new file mode 100644 index 00000000..c9cf3246 --- /dev/null +++ b/backend/internal/service/openai_client_transport.go @@ -0,0 +1,71 @@ +package service + +import ( + "strings" + + "github.com/gin-gonic/gin" +) + +// OpenAIClientTransport 表示客户端入站协议类型。 +type OpenAIClientTransport string + +const ( + OpenAIClientTransportUnknown OpenAIClientTransport = "" + OpenAIClientTransportHTTP OpenAIClientTransport = "http" + OpenAIClientTransportWS OpenAIClientTransport = "ws" +) + +const openAIClientTransportContextKey = "openai_client_transport" + +// SetOpenAIClientTransport 标记当前请求的客户端入站协议。 +func SetOpenAIClientTransport(c *gin.Context, transport OpenAIClientTransport) { + if c == nil { + return + } + normalized := normalizeOpenAIClientTransport(transport) + if normalized == OpenAIClientTransportUnknown { + return + } + c.Set(openAIClientTransportContextKey, string(normalized)) +} + +// GetOpenAIClientTransport 读取当前请求的客户端入站协议。 +func GetOpenAIClientTransport(c *gin.Context) OpenAIClientTransport { + if c == nil { + return OpenAIClientTransportUnknown + } + raw, ok := c.Get(openAIClientTransportContextKey) + if !ok || raw == nil { + return OpenAIClientTransportUnknown + } + + switch v := raw.(type) { + case OpenAIClientTransport: + return normalizeOpenAIClientTransport(v) + case string: + return normalizeOpenAIClientTransport(OpenAIClientTransport(v)) + default: + return OpenAIClientTransportUnknown + } +} + +func normalizeOpenAIClientTransport(transport OpenAIClientTransport) OpenAIClientTransport { + switch strings.ToLower(strings.TrimSpace(string(transport))) { + case string(OpenAIClientTransportHTTP), "http_sse", "sse": + return OpenAIClientTransportHTTP + case string(OpenAIClientTransportWS), "websocket": + return OpenAIClientTransportWS + default: + return OpenAIClientTransportUnknown + } +} + +func resolveOpenAIWSDecisionByClientTransport( + decision OpenAIWSProtocolDecision, + clientTransport OpenAIClientTransport, +) OpenAIWSProtocolDecision { + if clientTransport == OpenAIClientTransportHTTP { + return openAIWSHTTPDecision("client_protocol_http") + } + return decision +} diff --git a/backend/internal/service/openai_client_transport_test.go b/backend/internal/service/openai_client_transport_test.go new file mode 100644 index 00000000..ef90e614 --- /dev/null +++ b/backend/internal/service/openai_client_transport_test.go @@ -0,0 +1,107 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestOpenAIClientTransport_SetAndGet(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(c)) + + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) + + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + require.Equal(t, OpenAIClientTransportWS, GetOpenAIClientTransport(c)) +} + +func TestOpenAIClientTransport_GetNormalizesRawContextValue(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + rawValue any + want OpenAIClientTransport + }{ + { + name: "type_value_ws", + rawValue: OpenAIClientTransportWS, + want: OpenAIClientTransportWS, + }, + { + name: "http_sse_alias", + rawValue: "http_sse", + want: OpenAIClientTransportHTTP, + }, + { + name: "sse_alias", + rawValue: "sSe", + want: OpenAIClientTransportHTTP, + }, + { + name: "websocket_alias", + rawValue: "WebSocket", + want: OpenAIClientTransportWS, + }, + { + name: "invalid_string", + rawValue: "tcp", + want: OpenAIClientTransportUnknown, + }, + { + name: "invalid_type", + rawValue: 123, + want: OpenAIClientTransportUnknown, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Set(openAIClientTransportContextKey, tt.rawValue) + require.Equal(t, tt.want, GetOpenAIClientTransport(c)) + }) + } +} + +func TestOpenAIClientTransport_NilAndUnknownInput(t *testing.T) { + SetOpenAIClientTransport(nil, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(nil)) + + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + SetOpenAIClientTransport(c, OpenAIClientTransportUnknown) + _, exists := c.Get(openAIClientTransportContextKey) + require.False(t, exists) + + SetOpenAIClientTransport(c, OpenAIClientTransport(" ")) + _, exists = c.Get(openAIClientTransportContextKey) + require.False(t, exists) +} + +func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) { + base := OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_enabled", + } + + httpDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, httpDecision.Transport) + require.Equal(t, "client_protocol_http", httpDecision.Reason) + + wsDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportWS) + require.Equal(t, base, wsDecision) + + unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown) + require.Equal(t, base, unknownDecision) +} diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index cea81693..16befb82 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -2,73 +2,66 @@ package service import ( _ "embed" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" "strings" - "time" -) - -const ( - opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt" - codexCacheTTL = 15 * time.Minute ) //go:embed prompts/codex_cli_instructions.md var codexCLIInstructions string var codexModelMap = map[string]string{ - "gpt-5.3": "gpt-5.3", - "gpt-5.3-none": "gpt-5.3", - "gpt-5.3-low": "gpt-5.3", - "gpt-5.3-medium": "gpt-5.3", - "gpt-5.3-high": "gpt-5.3", - "gpt-5.3-xhigh": "gpt-5.3", - "gpt-5.3-codex": "gpt-5.3-codex", - "gpt-5.3-codex-low": "gpt-5.3-codex", - "gpt-5.3-codex-medium": "gpt-5.3-codex", - "gpt-5.3-codex-high": "gpt-5.3-codex", - "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt-5.1-codex": "gpt-5.1-codex", - "gpt-5.1-codex-low": "gpt-5.1-codex", - "gpt-5.1-codex-medium": "gpt-5.1-codex", - "gpt-5.1-codex-high": "gpt-5.1-codex", - "gpt-5.1-codex-max": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-low": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-high": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max", - "gpt-5.2": "gpt-5.2", - "gpt-5.2-none": "gpt-5.2", - "gpt-5.2-low": "gpt-5.2", - "gpt-5.2-medium": "gpt-5.2", - "gpt-5.2-high": "gpt-5.2", - "gpt-5.2-xhigh": "gpt-5.2", - "gpt-5.2-codex": "gpt-5.2-codex", - "gpt-5.2-codex-low": "gpt-5.2-codex", - "gpt-5.2-codex-medium": "gpt-5.2-codex", - "gpt-5.2-codex-high": "gpt-5.2-codex", - "gpt-5.2-codex-xhigh": "gpt-5.2-codex", - "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5.1": "gpt-5.1", - "gpt-5.1-none": "gpt-5.1", - "gpt-5.1-low": "gpt-5.1", - "gpt-5.1-medium": "gpt-5.1", - "gpt-5.1-high": "gpt-5.1", - "gpt-5.1-chat-latest": "gpt-5.1", - "gpt-5-codex": "gpt-5.1-codex", - "codex-mini-latest": "gpt-5.1-codex-mini", - "gpt-5-codex-mini": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5": "gpt-5.1", - "gpt-5-mini": "gpt-5.1", - "gpt-5-nano": "gpt-5.1", + "gpt-5.3": "gpt-5.3-codex", + "gpt-5.3-none": "gpt-5.3-codex", + "gpt-5.3-low": "gpt-5.3-codex", + "gpt-5.3-medium": "gpt-5.3-codex", + "gpt-5.3-high": "gpt-5.3-codex", + "gpt-5.3-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt-5.3-codex-spark-low": "gpt-5.3-codex", + "gpt-5.3-codex-spark-medium": "gpt-5.3-codex", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-low": "gpt-5.3-codex", + "gpt-5.3-codex-medium": "gpt-5.3-codex", + "gpt-5.3-codex-high": "gpt-5.3-codex", + "gpt-5.3-codex-xhigh": "gpt-5.3-codex", + "gpt-5.1-codex": "gpt-5.1-codex", + "gpt-5.1-codex-low": "gpt-5.1-codex", + "gpt-5.1-codex-medium": "gpt-5.1-codex", + "gpt-5.1-codex-high": "gpt-5.1-codex", + "gpt-5.1-codex-max": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-low": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-high": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max", + "gpt-5.2": "gpt-5.2", + "gpt-5.2-none": "gpt-5.2", + "gpt-5.2-low": "gpt-5.2", + "gpt-5.2-medium": "gpt-5.2", + "gpt-5.2-high": "gpt-5.2", + "gpt-5.2-xhigh": "gpt-5.2", + "gpt-5.2-codex": "gpt-5.2-codex", + "gpt-5.2-codex-low": "gpt-5.2-codex", + "gpt-5.2-codex-medium": "gpt-5.2-codex", + "gpt-5.2-codex-high": "gpt-5.2-codex", + "gpt-5.2-codex-xhigh": "gpt-5.2-codex", + "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", + "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini", + "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini", + "gpt-5.1": "gpt-5.1", + "gpt-5.1-none": "gpt-5.1", + "gpt-5.1-low": "gpt-5.1", + "gpt-5.1-medium": "gpt-5.1", + "gpt-5.1-high": "gpt-5.1", + "gpt-5.1-chat-latest": "gpt-5.1", + "gpt-5-codex": "gpt-5.1-codex", + "codex-mini-latest": "gpt-5.1-codex-mini", + "gpt-5-codex-mini": "gpt-5.1-codex-mini", + "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini", + "gpt-5-codex-mini-high": "gpt-5.1-codex-mini", + "gpt-5": "gpt-5.1", + "gpt-5-mini": "gpt-5.1", + "gpt-5-nano": "gpt-5.1", } type codexTransformResult struct { @@ -77,12 +70,6 @@ type codexTransformResult struct { PromptCacheKey string } -type opencodeCacheMetadata struct { - ETag string `json:"etag"` - LastFetch string `json:"lastFetch,omitempty"` - LastChecked int64 `json:"lastChecked"` -} - func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult { result := codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 @@ -112,13 +99,19 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran result.Modified = true } - if _, ok := reqBody["max_output_tokens"]; ok { - delete(reqBody, "max_output_tokens") - result.Modified = true - } - if _, ok := reqBody["max_completion_tokens"]; ok { - delete(reqBody, "max_completion_tokens") - result.Modified = true + // Strip parameters unsupported by codex models via the Responses API. + for _, key := range []string{ + "max_output_tokens", + "max_completion_tokens", + "temperature", + "top_p", + "frequency_penalty", + "presence_penalty", + } { + if _, ok := reqBody[key]; ok { + delete(reqBody, key) + result.Modified = true + } } if normalizeCodexTools(reqBody) { @@ -171,7 +164,7 @@ func normalizeCodexModel(model string) string { return "gpt-5.3-codex" } if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") { - return "gpt-5.3" + return "gpt-5.3-codex" } if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") { return "gpt-5.1-codex-max" @@ -216,54 +209,9 @@ func getNormalizedCodexModel(modelID string) string { return "" } -func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string { - cacheDir := codexCachePath("") - if cacheDir == "" { - return "" - } - cacheFile := filepath.Join(cacheDir, cacheFileName) - metaFile := filepath.Join(cacheDir, metaFileName) - - var cachedContent string - if content, ok := readFile(cacheFile); ok { - cachedContent = content - } - - var meta opencodeCacheMetadata - if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" { - if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL { - return cachedContent - } - } - - content, etag, status, err := fetchWithETag(url, meta.ETag) - if err == nil && status == http.StatusNotModified && cachedContent != "" { - return cachedContent - } - if err == nil && status >= 200 && status < 300 && content != "" { - _ = writeFile(cacheFile, content) - meta = opencodeCacheMetadata{ - ETag: etag, - LastFetch: time.Now().UTC().Format(time.RFC3339), - LastChecked: time.Now().UnixMilli(), - } - _ = writeJSON(metaFile, meta) - return content - } - - return cachedContent -} - func getOpenCodeCodexHeader() string { - // 优先从 opencode 仓库缓存获取指令。 - opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json") - - // 若 opencode 指令可用,直接返回。 - if opencodeInstructions != "" { - return opencodeInstructions - } - - // 否则回退使用本地 Codex CLI 指令。 + // 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。 + // 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions,避免读写缓存与外网依赖。 return getCodexCLIInstructions() } @@ -281,8 +229,8 @@ func GetCodexCLIInstructions() string { } // applyInstructions 处理 instructions 字段 -// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令) -// isCodexCLI=false: 优先使用 opencode 指令覆盖 +// isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令) +// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖 func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { if isCodexCLI { return applyCodexCLIInstructions(reqBody) @@ -291,13 +239,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { } // applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions -// 仅在 instructions 为空时添加 opencode 指令 +// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源) func applyCodexCLIInstructions(reqBody map[string]any) bool { if !isInstructionsEmpty(reqBody) { return false // 已有有效 instructions,不修改 } - instructions := strings.TrimSpace(getOpenCodeCodexHeader()) + instructions := strings.TrimSpace(getCodexCLIInstructions()) if instructions != "" { reqBody["instructions"] = instructions return true @@ -306,8 +254,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool { return false } -// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令 -// 优先使用 opencode 指令覆盖 +// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名) +// 优先使用内置 Codex CLI 指令覆盖 func applyOpenCodeInstructions(reqBody map[string]any) bool { instructions := strings.TrimSpace(getOpenCodeCodexHeader()) existingInstructions, _ := reqBody["instructions"].(string) @@ -489,85 +437,3 @@ func normalizeCodexTools(reqBody map[string]any) bool { return modified } - -func codexCachePath(filename string) string { - home, err := os.UserHomeDir() - if err != nil { - return "" - } - cacheDir := filepath.Join(home, ".opencode", "cache") - if filename == "" { - return cacheDir - } - return filepath.Join(cacheDir, filename) -} - -func readFile(path string) (string, bool) { - if path == "" { - return "", false - } - data, err := os.ReadFile(path) - if err != nil { - return "", false - } - return string(data), true -} - -func writeFile(path, content string) error { - if path == "" { - return fmt.Errorf("empty cache path") - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - return os.WriteFile(path, []byte(content), 0o644) -} - -func loadJSON(path string, target any) bool { - data, err := os.ReadFile(path) - if err != nil { - return false - } - if err := json.Unmarshal(data, target); err != nil { - return false - } - return true -} - -func writeJSON(path string, value any) error { - if path == "" { - return fmt.Errorf("empty json path") - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - data, err := json.Marshal(value) - if err != nil { - return err - } - return os.WriteFile(path, data, 0o644) -} - -func fetchWithETag(url, etag string) (string, string, int, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return "", "", 0, err - } - req.Header.Set("User-Agent", "sub2api-codex") - if etag != "" { - req.Header.Set("If-None-Match", etag) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", "", 0, err - } - defer func() { - _ = resp.Body.Close() - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", "", resp.StatusCode, err - } - return string(body), resp.Header.Get("etag"), resp.StatusCode, nil -} diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index cc0acafc..27093f6c 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -1,18 +1,13 @@ package service import ( - "encoding/json" - "os" - "path/filepath" "testing" - "time" "github.com/stretchr/testify/require" ) func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { // 续链场景:保留 item_reference 与 id,但不再强制 store=true。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.2", @@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { // 续链场景:显式 store=false 不再强制为 true,保持 false。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { // 显式 store=true 也会强制为 false。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) { // 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) { } func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) { - setupCodexCache(t) - reqBody := map[string]any{ "model": "gpt-5.1", "tools": []any{ @@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { // 空 input 应保持为空且不触发异常。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -178,97 +167,39 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { func TestNormalizeCodexModel_Gpt53(t *testing.T) { cases := map[string]string{ - "gpt-5.3": "gpt-5.3", - "gpt-5.3-codex": "gpt-5.3-codex", - "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt 5.3 codex": "gpt-5.3-codex", + "gpt-5.3": "gpt-5.3-codex", + "gpt-5.3-codex": "gpt-5.3-codex", + "gpt-5.3-codex-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt 5.3 codex": "gpt-5.3-codex", } for input, expected := range cases { require.Equal(t, expected, normalizeCodexModel(input)) } - } func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { - // Codex CLI 场景:已有 instructions 时保持不变 - setupCodexCache(t) + // Codex CLI 场景:已有 instructions 时不修改 reqBody := map[string]any{ "model": "gpt-5.1", - "instructions": "user custom instructions", - "input": []any{}, + "instructions": "existing instructions", } - result := applyCodexOAuthTransform(reqBody, true) + result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true instructions, ok := reqBody["instructions"].(string) require.True(t, ok) - require.Equal(t, "user custom instructions", instructions) - // instructions 未变,但其他字段(如 store、stream)可能被修改 - require.True(t, result.Modified) -} - -func TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty(t *testing.T) { - // Codex CLI 场景:无 instructions 时补充内置指令 - setupCodexCache(t) - - reqBody := map[string]any{ - "model": "gpt-5.1", - "input": []any{}, - } - - result := applyCodexOAuthTransform(reqBody, true) - - instructions, ok := reqBody["instructions"].(string) - require.True(t, ok) - require.NotEmpty(t, instructions) - require.True(t, result.Modified) -} - -func TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions(t *testing.T) { - // 非 Codex CLI 场景:使用 opencode 指令(缓存中有 header) - setupCodexCache(t) - - reqBody := map[string]any{ - "model": "gpt-5.1", - "input": []any{}, - } - - result := applyCodexOAuthTransform(reqBody, false) - - instructions, ok := reqBody["instructions"].(string) - require.True(t, ok) - require.Equal(t, "header", instructions) // setupCodexCache 设置的缓存内容 - require.True(t, result.Modified) -} - -func setupCodexCache(t *testing.T) { - t.Helper() - - // 使用临时 HOME 避免触发网络拉取 header。 - // Windows 使用 USERPROFILE,Unix 使用 HOME。 - tempDir := t.TempDir() - t.Setenv("HOME", tempDir) - t.Setenv("USERPROFILE", tempDir) - - cacheDir := filepath.Join(tempDir, ".opencode", "cache") - require.NoError(t, os.MkdirAll(cacheDir, 0o755)) - require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644)) - - meta := map[string]any{ - "etag": "", - "lastFetch": time.Now().UTC().Format(time.RFC3339), - "lastChecked": time.Now().UnixMilli(), - } - data, err := json.Marshal(meta) - require.NoError(t, err) - require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644)) + require.Equal(t, "existing instructions", instructions) + // Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变 + _ = result } func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) { // Codex CLI 场景:无 instructions 时补充默认值 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -284,8 +215,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T } func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) { - // 非 Codex CLI 场景:使用 opencode 指令覆盖 - setupCodexCache(t) + // 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖 reqBody := map[string]any{ "model": "gpt-5.1", diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index fbe81cb4..f624d92a 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -10,20 +10,24 @@ import ( "errors" "fmt" "io" - "log" + "math/rand" "net/http" - "regexp" "sort" "strconv" "strings" + "sync" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" ) const ( @@ -32,20 +36,62 @@ const ( // OpenAI Platform API for API Key accounts (fallback) openaiPlatformAPIURL = "https://api.openai.com/v1/responses" openaiStickySessionTTL = time.Hour // 粘性会话TTL + codexCLIUserAgent = "codex_cli_rs/0.104.0" + // codex_cli_only 拒绝时单个请求头日志长度上限(字符) + codexCLIOnlyHeaderValueMaxBytes = 256 + + // OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。 + OpenAIParsedRequestBodyKey = "openai_parsed_request_body" + // OpenAI WS Mode 失败后的重连次数上限(不含首次尝试)。 + // 与 Codex 客户端保持一致:失败后最多重连 5 次。 + openAIWSReconnectRetryLimit = 5 + // OpenAI WS Mode 重连退避默认值(可由配置覆盖)。 + openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond + openAIWSRetryBackoffMaxDefault = 2 * time.Second + openAIWSRetryJitterRatioDefault = 0.2 ) -// openaiSSEDataRe matches SSE data lines with optional whitespace after colon. -// Some upstream APIs return non-standard "data:" without space (should be "data: "). -var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`) - -// OpenAI allowed headers whitelist (for non-OAuth accounts) +// OpenAI allowed headers whitelist (for non-passthrough). var openaiAllowedHeaders = map[string]bool{ - "accept-language": true, - "content-type": true, - "conversation_id": true, - "user-agent": true, - "originator": true, - "session_id": true, + "accept-language": true, + "content-type": true, + "conversation_id": true, + "user-agent": true, + "originator": true, + "session_id": true, + "x-codex-turn-state": true, + "x-codex-turn-metadata": true, +} + +// OpenAI passthrough allowed headers whitelist. +// 透传模式下仅放行这些低风险请求头,避免将非标准/环境噪声头传给上游触发风控。 +var openaiPassthroughAllowedHeaders = map[string]bool{ + "accept": true, + "accept-language": true, + "content-type": true, + "conversation_id": true, + "openai-beta": true, + "user-agent": true, + "originator": true, + "session_id": true, + "x-codex-turn-state": true, + "x-codex-turn-metadata": true, +} + +// codex_cli_only 拒绝时记录的请求头白名单(仅用于诊断日志,不参与上游透传) +var codexCLIOnlyDebugHeaderWhitelist = []string{ + "User-Agent", + "Content-Type", + "Accept", + "Accept-Language", + "OpenAI-Beta", + "Originator", + "Session_ID", + "Conversation_ID", + "X-Request-ID", + "X-Client-Request-ID", + "X-Forwarded-For", + "X-Real-IP", } // OpenAICodexUsageSnapshot represents Codex API usage limits from response headers @@ -163,10 +209,40 @@ type OpenAIForwardResult struct { // Stored for usage records display; nil means not provided / not applicable. ReasoningEffort *string Stream bool + OpenAIWSMode bool Duration time.Duration FirstTokenMs *int } +type OpenAIWSRetryMetricsSnapshot struct { + RetryAttemptsTotal int64 `json:"retry_attempts_total"` + RetryBackoffMsTotal int64 `json:"retry_backoff_ms_total"` + RetryExhaustedTotal int64 `json:"retry_exhausted_total"` + NonRetryableFastFallbackTotal int64 `json:"non_retryable_fast_fallback_total"` +} + +type OpenAICompatibilityFallbackMetricsSnapshot struct { + SessionHashLegacyReadFallbackTotal int64 `json:"session_hash_legacy_read_fallback_total"` + SessionHashLegacyReadFallbackHit int64 `json:"session_hash_legacy_read_fallback_hit"` + SessionHashLegacyDualWriteTotal int64 `json:"session_hash_legacy_dual_write_total"` + SessionHashLegacyReadHitRate float64 `json:"session_hash_legacy_read_hit_rate"` + + MetadataLegacyFallbackIsMaxTokensOneHaikuTotal int64 `json:"metadata_legacy_fallback_is_max_tokens_one_haiku_total"` + MetadataLegacyFallbackThinkingEnabledTotal int64 `json:"metadata_legacy_fallback_thinking_enabled_total"` + MetadataLegacyFallbackPrefetchedStickyAccount int64 `json:"metadata_legacy_fallback_prefetched_sticky_account_total"` + MetadataLegacyFallbackPrefetchedStickyGroup int64 `json:"metadata_legacy_fallback_prefetched_sticky_group_total"` + MetadataLegacyFallbackSingleAccountRetryTotal int64 `json:"metadata_legacy_fallback_single_account_retry_total"` + MetadataLegacyFallbackAccountSwitchCountTotal int64 `json:"metadata_legacy_fallback_account_switch_count_total"` + MetadataLegacyFallbackTotal int64 `json:"metadata_legacy_fallback_total"` +} + +type openAIWSRetryMetrics struct { + retryAttempts atomic.Int64 + retryBackoffMs atomic.Int64 + retryExhausted atomic.Int64 + nonRetryableFastFallback atomic.Int64 +} + // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { accountRepo AccountRepository @@ -175,6 +251,7 @@ type OpenAIGatewayService struct { userSubRepo UserSubscriptionRepository cache GatewayCache cfg *config.Config + codexDetector CodexClientRestrictionDetector schedulerSnapshot *SchedulerSnapshotService concurrencyService *ConcurrencyService billingService *BillingService @@ -184,6 +261,19 @@ type OpenAIGatewayService struct { deferredService *DeferredService openAITokenProvider *OpenAITokenProvider toolCorrector *CodexToolCorrector + openaiWSResolver OpenAIWSProtocolResolver + + openaiWSPoolOnce sync.Once + openaiWSStateStoreOnce sync.Once + openaiSchedulerOnce sync.Once + openaiWSPool *openAIWSConnPool + openaiWSStateStore OpenAIWSStateStore + openaiScheduler OpenAIAccountScheduler + openaiAccountStats *openAIAccountRuntimeStats + + openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time + openaiWSRetryMetrics openAIWSRetryMetrics + responseHeaderFilter *responseheaders.CompiledHeaderFilter } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -203,23 +293,587 @@ func NewOpenAIGatewayService( deferredService *DeferredService, openAITokenProvider *OpenAITokenProvider, ) *OpenAIGatewayService { - return &OpenAIGatewayService{ - accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - httpUpstream: httpUpstream, - deferredService: deferredService, - openAITokenProvider: openAITokenProvider, - toolCorrector: NewCodexToolCorrector(), + svc := &OpenAIGatewayService{ + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + cache: cache, + cfg: cfg, + codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + httpUpstream: httpUpstream, + deferredService: deferredService, + openAITokenProvider: openAITokenProvider, + toolCorrector: NewCodexToolCorrector(), + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + responseHeaderFilter: compileResponseHeaderFilter(cfg), } + svc.logOpenAIWSModeBootstrap() + return svc +} + +// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。 +// 应在应用优雅关闭时调用。 +func (s *OpenAIGatewayService) CloseOpenAIWSPool() { + if s != nil && s.openaiWSPool != nil { + s.openaiWSPool.Close() + } +} + +func (s *OpenAIGatewayService) logOpenAIWSModeBootstrap() { + if s == nil || s.cfg == nil { + return + } + wsCfg := s.cfg.Gateway.OpenAIWS + logOpenAIWSModeInfo( + "bootstrap enabled=%v oauth_enabled=%v apikey_enabled=%v force_http=%v responses_websockets_v2=%v responses_websockets=%v payload_log_sample_rate=%.3f event_flush_batch_size=%d event_flush_interval_ms=%d prewarm_cooldown_ms=%d retry_backoff_initial_ms=%d retry_backoff_max_ms=%d retry_jitter_ratio=%.3f retry_total_budget_ms=%d ws_read_limit_bytes=%d", + wsCfg.Enabled, + wsCfg.OAuthEnabled, + wsCfg.APIKeyEnabled, + wsCfg.ForceHTTP, + wsCfg.ResponsesWebsocketsV2, + wsCfg.ResponsesWebsockets, + wsCfg.PayloadLogSampleRate, + wsCfg.EventFlushBatchSize, + wsCfg.EventFlushIntervalMS, + wsCfg.PrewarmCooldownMS, + wsCfg.RetryBackoffInitialMS, + wsCfg.RetryBackoffMaxMS, + wsCfg.RetryJitterRatio, + wsCfg.RetryTotalBudgetMS, + openAIWSMessageReadLimitBytes, + ) +} + +func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector { + if s != nil && s.codexDetector != nil { + return s.codexDetector + } + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + return NewOpenAICodexClientRestrictionDetector(cfg) +} + +func (s *OpenAIGatewayService) getOpenAIWSProtocolResolver() OpenAIWSProtocolResolver { + if s != nil && s.openaiWSResolver != nil { + return s.openaiWSResolver + } + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + return NewOpenAIWSProtocolResolver(cfg) +} + +func classifyOpenAIWSReconnectReason(err error) (string, bool) { + if err == nil { + return "", false + } + var fallbackErr *openAIWSFallbackError + if !errors.As(err, &fallbackErr) || fallbackErr == nil { + return "", false + } + reason := strings.TrimSpace(fallbackErr.Reason) + if reason == "" { + return "", false + } + + baseReason := strings.TrimPrefix(reason, "prewarm_") + + switch baseReason { + case "policy_violation", + "message_too_big", + "upgrade_required", + "ws_unsupported", + "auth_failed", + "previous_response_not_found": + return reason, false + } + + switch baseReason { + case "read_event", + "write_request", + "write", + "acquire_timeout", + "acquire_conn", + "conn_queue_full", + "dial_failed", + "upstream_5xx", + "event_error", + "error_event", + "upstream_error_event", + "ws_connection_limit_reached", + "missing_final_response": + return reason, true + default: + return reason, false + } +} + +func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType string, clientMessage string, upstreamMessage string, ok bool) { + if err == nil { + return 0, "", "", "", false + } + var fallbackErr *openAIWSFallbackError + if !errors.As(err, &fallbackErr) || fallbackErr == nil { + return 0, "", "", "", false + } + + reason := strings.TrimSpace(fallbackErr.Reason) + reason = strings.TrimPrefix(reason, "prewarm_") + if reason == "" { + return 0, "", "", "", false + } + + var dialErr *openAIWSDialError + if fallbackErr.Err != nil && errors.As(fallbackErr.Err, &dialErr) && dialErr != nil { + if dialErr.StatusCode > 0 { + statusCode = dialErr.StatusCode + } + if dialErr.Err != nil { + upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(dialErr.Err.Error())) + } + } + + switch reason { + case "previous_response_not_found": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + errType = "invalid_request_error" + if upstreamMessage == "" { + upstreamMessage = "previous response not found" + } + case "upgrade_required": + if statusCode == 0 { + statusCode = http.StatusUpgradeRequired + } + case "ws_unsupported": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + case "auth_failed": + if statusCode == 0 { + statusCode = http.StatusUnauthorized + } + case "upstream_rate_limited": + if statusCode == 0 { + statusCode = http.StatusTooManyRequests + } + default: + if statusCode == 0 { + return 0, "", "", "", false + } + } + + if upstreamMessage == "" && fallbackErr.Err != nil { + upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(fallbackErr.Err.Error())) + } + if upstreamMessage == "" { + switch reason { + case "upgrade_required": + upstreamMessage = "upstream websocket upgrade required" + case "ws_unsupported": + upstreamMessage = "upstream websocket not supported" + case "auth_failed": + upstreamMessage = "upstream authentication failed" + case "upstream_rate_limited": + upstreamMessage = "upstream rate limit exceeded, please retry later" + default: + upstreamMessage = "Upstream request failed" + } + } + + if errType == "" { + if statusCode == http.StatusTooManyRequests { + errType = "rate_limit_error" + } else { + errType = "upstream_error" + } + } + clientMessage = upstreamMessage + return statusCode, errType, clientMessage, upstreamMessage, true +} + +func (s *OpenAIGatewayService) writeOpenAIWSFallbackErrorResponse(c *gin.Context, account *Account, wsErr error) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(wsErr) + if !ok { + return false + } + if strings.TrimSpace(clientMessage) == "" { + clientMessage = "Upstream request failed" + } + if strings.TrimSpace(upstreamMessage) == "" { + upstreamMessage = clientMessage + } + + setOpsUpstreamError(c, statusCode, upstreamMessage, "") + if account != nil { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: statusCode, + Kind: "ws_error", + Message: upstreamMessage, + }) + } + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": clientMessage, + }, + }) + return true +} + +func (s *OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration { + if attempt <= 0 { + return 0 + } + + initial := openAIWSRetryBackoffInitialDefault + maxBackoff := openAIWSRetryBackoffMaxDefault + jitterRatio := openAIWSRetryJitterRatioDefault + if s != nil && s.cfg != nil { + wsCfg := s.cfg.Gateway.OpenAIWS + if wsCfg.RetryBackoffInitialMS > 0 { + initial = time.Duration(wsCfg.RetryBackoffInitialMS) * time.Millisecond + } + if wsCfg.RetryBackoffMaxMS > 0 { + maxBackoff = time.Duration(wsCfg.RetryBackoffMaxMS) * time.Millisecond + } + if wsCfg.RetryJitterRatio >= 0 { + jitterRatio = wsCfg.RetryJitterRatio + } + } + if initial <= 0 { + return 0 + } + if maxBackoff <= 0 { + maxBackoff = initial + } + if maxBackoff < initial { + maxBackoff = initial + } + if jitterRatio < 0 { + jitterRatio = 0 + } + if jitterRatio > 1 { + jitterRatio = 1 + } + + shift := attempt - 1 + if shift < 0 { + shift = 0 + } + backoff := initial + if shift > 0 { + backoff = initial * time.Duration(1< maxBackoff { + backoff = maxBackoff + } + if jitterRatio <= 0 { + return backoff + } + jitter := time.Duration(float64(backoff) * jitterRatio) + if jitter <= 0 { + return backoff + } + delta := time.Duration(rand.Int63n(int64(jitter)*2+1)) - jitter + withJitter := backoff + delta + if withJitter < 0 { + return 0 + } + return withJitter +} + +func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration { + if s != nil && s.cfg != nil { + ms := s.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS + if ms <= 0 { + return 0 + } + return time.Duration(ms) * time.Millisecond + } + return 0 +} + +func (s *OpenAIGatewayService) recordOpenAIWSRetryAttempt(backoff time.Duration) { + if s == nil { + return + } + s.openaiWSRetryMetrics.retryAttempts.Add(1) + if backoff > 0 { + s.openaiWSRetryMetrics.retryBackoffMs.Add(backoff.Milliseconds()) + } +} + +func (s *OpenAIGatewayService) recordOpenAIWSRetryExhausted() { + if s == nil { + return + } + s.openaiWSRetryMetrics.retryExhausted.Add(1) +} + +func (s *OpenAIGatewayService) recordOpenAIWSNonRetryableFastFallback() { + if s == nil { + return + } + s.openaiWSRetryMetrics.nonRetryableFastFallback.Add(1) +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSRetryMetrics() OpenAIWSRetryMetricsSnapshot { + if s == nil { + return OpenAIWSRetryMetricsSnapshot{} + } + return OpenAIWSRetryMetricsSnapshot{ + RetryAttemptsTotal: s.openaiWSRetryMetrics.retryAttempts.Load(), + RetryBackoffMsTotal: s.openaiWSRetryMetrics.retryBackoffMs.Load(), + RetryExhaustedTotal: s.openaiWSRetryMetrics.retryExhausted.Load(), + NonRetryableFastFallbackTotal: s.openaiWSRetryMetrics.nonRetryableFastFallback.Load(), + } +} + +func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMetricsSnapshot { + legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal := openAIStickyCompatStats() + isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount := RequestMetadataFallbackStats() + + readHitRate := float64(0) + if legacyReadFallbackTotal > 0 { + readHitRate = float64(legacyReadFallbackHit) / float64(legacyReadFallbackTotal) + } + metadataFallbackTotal := isMaxTokensOneHaiku + thinkingEnabled + prefetchedStickyAccount + prefetchedStickyGroup + singleAccountRetry + accountSwitchCount + + return OpenAICompatibilityFallbackMetricsSnapshot{ + SessionHashLegacyReadFallbackTotal: legacyReadFallbackTotal, + SessionHashLegacyReadFallbackHit: legacyReadFallbackHit, + SessionHashLegacyDualWriteTotal: legacyDualWriteTotal, + SessionHashLegacyReadHitRate: readHitRate, + + MetadataLegacyFallbackIsMaxTokensOneHaikuTotal: isMaxTokensOneHaiku, + MetadataLegacyFallbackThinkingEnabledTotal: thinkingEnabled, + MetadataLegacyFallbackPrefetchedStickyAccount: prefetchedStickyAccount, + MetadataLegacyFallbackPrefetchedStickyGroup: prefetchedStickyGroup, + MetadataLegacyFallbackSingleAccountRetryTotal: singleAccountRetry, + MetadataLegacyFallbackAccountSwitchCountTotal: accountSwitchCount, + MetadataLegacyFallbackTotal: metadataFallbackTotal, + } +} + +func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { + return s.getCodexClientRestrictionDetector().Detect(c, account) +} + +func getAPIKeyIDFromContext(c *gin.Context) int64 { + if c == nil { + return 0 + } + v, exists := c.Get("api_key") + if !exists { + return 0 + } + apiKey, ok := v.(*APIKey) + if !ok || apiKey == nil { + return 0 + } + return apiKey.ID +} + +func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) { + if !result.Enabled { + return + } + if ctx == nil { + ctx = context.Background() + } + accountID := int64(0) + if account != nil { + accountID = account.ID + } + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.Bool("codex_cli_only_enabled", result.Enabled), + zap.Bool("codex_official_client_match", result.Matched), + zap.String("reject_reason", result.Reason), + } + if apiKeyID > 0 { + fields = append(fields, zap.Int64("api_key_id", apiKeyID)) + } + if !result.Matched { + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body) + } + log := logger.FromContext(ctx).With(fields...) + if result.Matched { + return + } + log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求") +} + +func appendCodexCLIOnlyRejectedRequestFields(fields []zap.Field, c *gin.Context, body []byte) []zap.Field { + if c == nil || c.Request == nil { + return fields + } + + req := c.Request + requestModel, requestStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + fields = append(fields, + zap.String("request_method", strings.TrimSpace(req.Method)), + zap.String("request_path", strings.TrimSpace(req.URL.Path)), + zap.String("request_query", strings.TrimSpace(req.URL.RawQuery)), + zap.String("request_host", strings.TrimSpace(req.Host)), + zap.String("request_client_ip", strings.TrimSpace(c.ClientIP())), + zap.String("request_remote_addr", strings.TrimSpace(req.RemoteAddr)), + zap.String("request_user_agent", strings.TrimSpace(req.Header.Get("User-Agent"))), + zap.String("request_content_type", strings.TrimSpace(req.Header.Get("Content-Type"))), + zap.Int64("request_content_length", req.ContentLength), + zap.Bool("request_stream", requestStream), + ) + if requestModel != "" { + fields = append(fields, zap.String("request_model", requestModel)) + } + if promptCacheKey != "" { + fields = append(fields, zap.String("request_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey))) + } + + if headers := snapshotCodexCLIOnlyHeaders(req.Header); len(headers) > 0 { + fields = append(fields, zap.Any("request_headers", headers)) + } + fields = append(fields, zap.Int("request_body_size", len(body))) + return fields +} + +func snapshotCodexCLIOnlyHeaders(header http.Header) map[string]string { + if len(header) == 0 { + return nil + } + result := make(map[string]string, len(codexCLIOnlyDebugHeaderWhitelist)) + for _, key := range codexCLIOnlyDebugHeaderWhitelist { + value := strings.TrimSpace(header.Get(key)) + if value == "" { + continue + } + result[strings.ToLower(key)] = truncateString(value, codexCLIOnlyHeaderValueMaxBytes) + } + return result +} + +func hashSensitiveValueForLog(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + sum := sha256.Sum256([]byte(value)) + return hex.EncodeToString(sum[:8]) +} + +func logOpenAIInstructionsRequiredDebug( + ctx context.Context, + c *gin.Context, + account *Account, + upstreamStatusCode int, + upstreamMsg string, + requestBody []byte, + upstreamBody []byte, +) { + msg := strings.TrimSpace(upstreamMsg) + if !isOpenAIInstructionsRequiredError(upstreamStatusCode, msg, upstreamBody) { + return + } + if ctx == nil { + ctx = context.Background() + } + + accountID := int64(0) + accountName := "" + if account != nil { + accountID = account.ID + accountName = strings.TrimSpace(account.Name) + } + + userAgent := "" + if c != nil { + userAgent = strings.TrimSpace(c.GetHeader("User-Agent")) + } + + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.String("account_name", accountName), + zap.Int("upstream_status_code", upstreamStatusCode), + zap.String("upstream_error_message", msg), + zap.String("request_user_agent", userAgent), + zap.Bool("codex_official_client_match", openai.IsCodexCLIRequest(userAgent)), + } + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody) + + logger.FromContext(ctx).With(fields...).Warn("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查") +} + +func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool { + if upstreamStatusCode != http.StatusBadRequest { + return false + } + + hasInstructionRequired := func(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if strings.Contains(lower, "instructions are required") { + return true + } + if strings.Contains(lower, "required parameter: 'instructions'") { + return true + } + if strings.Contains(lower, "required parameter: instructions") { + return true + } + if strings.Contains(lower, "missing required parameter") && strings.Contains(lower, "instructions") { + return true + } + return strings.Contains(lower, "instruction") && strings.Contains(lower, "required") + } + + if hasInstructionRequired(upstreamMsg) { + return true + } + if len(upstreamBody) == 0 { + return false + } + + errMsg := gjson.GetBytes(upstreamBody, "error.message").String() + errMsgLower := strings.ToLower(strings.TrimSpace(errMsg)) + errCode := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.code").String())) + errParam := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.param").String())) + errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.type").String())) + + if errParam == "instructions" { + return true + } + if hasInstructionRequired(errMsg) { + return true + } + if strings.Contains(errCode, "missing_required_parameter") && strings.Contains(errMsgLower, "instructions") { + return true + } + if strings.Contains(errType, "invalid_request") && strings.Contains(errMsgLower, "instructions") && strings.Contains(errMsgLower, "required") { + return true + } + + return false } // GenerateSessionHash generates a sticky-session hash for OpenAI requests. @@ -228,7 +882,7 @@ func NewOpenAIGatewayService( // 1. Header: session_id // 2. Header: conversation_id // 3. Body: prompt_cache_key (opencode) -func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string { +func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string { if c == nil { return "" } @@ -237,17 +891,35 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[s if sessionID == "" { sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) } - if sessionID == "" && reqBody != nil { - if v, ok := reqBody["prompt_cache_key"].(string); ok { - sessionID = strings.TrimSpace(v) - } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) } if sessionID == "" { return "" } - hash := sha256.Sum256([]byte(sessionID)) - return hex.EncodeToString(hash[:]) + currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash +} + +// GenerateSessionHashWithFallback 先按常规信号生成会话哈希; +// 当未携带 session_id/conversation_id/prompt_cache_key 时,使用 fallbackSeed 生成稳定哈希。 +// 该方法用于 WS ingress,避免会话信号缺失时发生跨账号漂移。 +func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, body []byte, fallbackSeed string) string { + sessionHash := s.GenerateSessionHash(c, body) + if sessionHash != "" { + return sessionHash + } + + seed := strings.TrimSpace(fallbackSeed) + if seed == "" { + return "" + } + + currentHash, legacyHash := deriveOpenAISessionHashes(seed) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash } // BindStickySession sets session -> account binding with standard TTL. @@ -255,7 +927,11 @@ func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *i if sessionHash == "" || accountID <= 0 { return nil } - return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL) + ttl := openaiStickySessionTTL + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 { + ttl = time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second + } + return s.setStickySessionAccountID(ctx, groupID, sessionHash, accountID, ttl) } // SelectAccount selects an OpenAI account with sticky session support @@ -271,11 +947,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。 func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - cacheKey := "openai:" + sessionHash + return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, 0) +} +func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { // 1. 尝试粘性会话命中 // Try sticky session hit - if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil { + if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { return account, nil } @@ -300,7 +978,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // 4. 设置粘性会话绑定 // Set sticky session binding if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) } return selected, nil @@ -311,14 +989,18 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // // tryStickySessionHit attempts to get account from sticky session. // Returns account if hit and usable; clears session and returns nil if account is unavailable. -func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account { +func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) *Account { if sessionHash == "" { return nil } - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) - if err != nil || accountID <= 0 { - return nil + accountID := stickyAccountID + if accountID <= 0 { + var err error + accountID, err = s.getStickySessionAccountID(ctx, groupID, sessionHash) + if err != nil || accountID <= 0 { + return nil + } } if _, excluded := excludedIDs[accountID]; excluded { @@ -333,7 +1015,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 检查账号是否需要清理粘性会话 // Check if sticky session should be cleared if shouldClearStickySession(account, requestedModel) { - _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey) + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil } @@ -348,7 +1030,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 刷新会话 TTL 并返回账号 // Refresh session TTL and return account - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL) + _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) return account } @@ -434,12 +1116,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex cfg := s.schedulingConfig() var stickyAccountID int64 if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil { + if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { stickyAccountID = accountID } } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) + account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID) if err != nil { return nil, err } @@ -494,19 +1176,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex // ============ Layer 1: Sticky session ============ if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { + accountID := stickyAccountID + if accountID > 0 && !isExcluded(accountID) { account, err := s.getSchedulableAccount(ctx, accountID) if err == nil { clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { - _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) + _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -570,7 +1252,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, acc.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: acc, @@ -580,10 +1262,6 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } } } else { - type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - } var available []accountWithLoad for _, acc := range candidates { loadInfo := loadMap[acc.ID] @@ -618,12 +1296,13 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex return a.account.LastUsedAt.Before(*b.account.LastUsedAt) } }) + shuffleWithinSortGroups(available) for _, item := range available { result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, item.account.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: item.account, @@ -747,43 +1426,159 @@ func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, re func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) { startTime := time.Now() - // Parse request body once (avoid multiple parse/serialize cycles) - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err != nil { - return nil, fmt.Errorf("parse request: %w", err) + restrictionResult := s.detectCodexClientRestriction(c, account) + apiKeyID := getAPIKeyIDFromContext(c) + logCodexCLIOnlyDetection(ctx, c, account, apiKeyID, restrictionResult, body) + if restrictionResult.Enabled && !restrictionResult.Matched { + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "forbidden_error", + "message": "This account only allows Codex official clients", + }, + }) + return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed") } - // Extract model and stream from parsed body - reqModel, _ := reqBody["model"].(string) - reqStream, _ := reqBody["stream"].(bool) - promptCacheKey := "" - if v, ok := reqBody["prompt_cache_key"].(string); ok { - promptCacheKey = strings.TrimSpace(v) + originalBody := body + reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + originalModel := reqModel + + isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) + clientTransport := GetOpenAIClientTransport(c) + // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 + wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport) + if c != nil { + c.Set("openai_ws_transport_decision", string(wsDecision.Transport)) + c.Set("openai_ws_transport_reason", wsDecision.Reason) + } + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeDebug( + "selected account_id=%d account_type=%s transport=%s reason=%s model=%s stream=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + normalizeOpenAIWSLogValue(wsDecision.Reason), + reqModel, + reqStream, + ) + } + // 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket { + if c != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.", + }, + }) + } + return nil, errors.New("openai ws v1 is temporarily unsupported; use ws v2") + } + passthroughEnabled := account.IsOpenAIPassthroughEnabled() + if passthroughEnabled { + // 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。 + reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel) + return s.forwardOpenAIPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime) + } + + reqBody, err := getOpenAIRequestBodyMap(c, body) + if err != nil { + return nil, err + } + + if v, ok := reqBody["model"].(string); ok { + reqModel = v + originalModel = reqModel + } + if v, ok := reqBody["stream"].(bool); ok { + reqStream = v + } + if promptCacheKey == "" { + if v, ok := reqBody["prompt_cache_key"].(string); ok { + promptCacheKey = strings.TrimSpace(v) + } } // Track if body needs re-serialization bodyModified := false - originalModel := reqModel + // 单字段补丁快速路径:只要整个变更集最终可归约为同一路径的 set/delete,就避免全量 Marshal。 + patchDisabled := false + patchHasOp := false + patchDelete := false + patchPath := "" + var patchValue any + markPatchSet := func(path string, value any) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return + } + if patchDisabled { + return + } + if !patchHasOp { + patchHasOp = true + patchDelete = false + patchPath = path + patchValue = value + return + } + if patchDelete || patchPath != path { + patchDisabled = true + return + } + patchValue = value + } + markPatchDelete := func(path string) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return + } + if patchDisabled { + return + } + if !patchHasOp { + patchHasOp = true + patchDelete = true + patchPath = path + return + } + if !patchDelete || patchPath != path { + patchDisabled = true + } + } + disablePatch := func() { + patchDisabled = true + } - isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) + // 非透传模式下,保持历史行为:非 Codex CLI 请求在 instructions 为空时注入默认指令。 + if !isCodexCLI && isInstructionsEmpty(reqBody) { + if instructions := strings.TrimSpace(GetOpenCodeInstructions()); instructions != "" { + reqBody["instructions"] = instructions + bodyModified = true + markPatchSet("instructions", instructions) + } + } // 对所有请求执行模型映射(包含 Codex CLI)。 mappedModel := account.GetMappedModel(reqModel) if mappedModel != reqModel { - log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) reqBody["model"] = mappedModel bodyModified = true + markPatchSet("model", mappedModel) } // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 if model, ok := reqBody["model"].(string); ok { normalizedModel := normalizeCodexModel(model) if normalizedModel != "" && normalizedModel != model { - log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", model, normalizedModel, account.Name, account.Type, isCodexCLI) reqBody["model"] = normalizedModel mappedModel = normalizedModel bodyModified = true + markPatchSet("model", normalizedModel) } } @@ -792,7 +1587,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" { reasoning["effort"] = "none" bodyModified = true - log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) + markPatchSet("reasoning.effort", "none") + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) } } @@ -800,6 +1596,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI) if codexResult.Modified { bodyModified = true + disablePatch() } if codexResult.NormalizedModel != "" { mappedModel = codexResult.NormalizedModel @@ -819,22 +1616,27 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if account.Type == AccountTypeAPIKey { delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") } case PlatformAnthropic: // For Anthropic (Claude), convert to max_tokens delete(reqBody, "max_output_tokens") + markPatchDelete("max_output_tokens") if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens { reqBody["max_tokens"] = maxOutputTokens + disablePatch() } bodyModified = true case PlatformGemini: // For Gemini, remove (will be handled by Gemini-specific transform) delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") default: // For unknown platforms, remove to be safe delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") } } @@ -843,24 +1645,51 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI { delete(reqBody, "max_completion_tokens") bodyModified = true + markPatchDelete("max_completion_tokens") } } // Remove unsupported fields (not supported by upstream OpenAI API) - for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} { + unsupportedFields := []string{"prompt_cache_retention", "safety_identifier"} + for _, unsupportedField := range unsupportedFields { if _, has := reqBody[unsupportedField]; has { delete(reqBody, unsupportedField) bodyModified = true + markPatchDelete(unsupportedField) } } } + // 仅在 WSv2 模式保留 previous_response_id,其他模式(HTTP/WSv1)统一过滤。 + // 注意:该规则同样适用于 Codex CLI 请求,避免 WSv1 向上游透传不支持字段。 + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + if _, has := reqBody["previous_response_id"]; has { + delete(reqBody, "previous_response_id") + bodyModified = true + markPatchDelete("previous_response_id") + } + } + // Re-serialize body only if modified if bodyModified { - var err error - body, err = json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("serialize request body: %w", err) + serializedByPatch := false + if !patchDisabled && patchHasOp { + var patchErr error + if patchDelete { + body, patchErr = sjson.DeleteBytes(body, patchPath) + } else { + body, patchErr = sjson.SetBytes(body, patchPath, patchValue) + } + if patchErr == nil { + serializedByPatch = true + } + } + if !serializedByPatch { + var marshalErr error + body, marshalErr = json.Marshal(reqBody) + if marshalErr != nil { + return nil, fmt.Errorf("serialize request body: %w", marshalErr) + } } } @@ -870,6 +1699,184 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco return nil, err } + // Capture upstream request body for ops retry of this attempt. + setOpsUpstreamRequestBody(c, body) + + // 命中 WS 时仅走 WebSocket Mode;不再自动回退 HTTP。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { + wsReqBody := reqBody + if len(reqBody) > 0 { + wsReqBody = make(map[string]any, len(reqBody)) + for k, v := range reqBody { + wsReqBody[k] = v + } + } + _, hasPreviousResponseID := wsReqBody["previous_response_id"] + logOpenAIWSModeDebug( + "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", + account.ID, + account.Type, + mappedModel, + reqStream, + hasPreviousResponseID, + ) + maxAttempts := openAIWSReconnectRetryLimit + 1 + wsAttempts := 0 + var wsResult *OpenAIForwardResult + var wsErr error + wsLastFailureReason := "" + wsPrevResponseRecoveryTried := false + recoverPrevResponseNotFound := func(attempt int) bool { + if wsPrevResponseRecoveryTried { + return false + } + previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id") + if previousResponseID == "" { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=missing_previous_response_id previous_response_id_present=false", + account.ID, + attempt, + ) + return false + } + if HasFunctionCallOutput(wsReqBody) { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=has_function_call_output previous_response_id_present=true", + account.ID, + attempt, + ) + return false + } + delete(wsReqBody, "previous_response_id") + wsPrevResponseRecoveryTried = true + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery account_id=%d attempt=%d action=drop_previous_response_id retry=1 previous_response_id=%s previous_response_id_kind=%s", + account.ID, + attempt, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)), + ) + return true + } + retryBudget := s.openAIWSRetryTotalBudget() + retryStartedAt := time.Now() + wsRetryLoop: + for attempt := 1; attempt <= maxAttempts; attempt++ { + wsAttempts = attempt + wsResult, wsErr = s.forwardOpenAIWSV2( + ctx, + c, + account, + wsReqBody, + token, + wsDecision, + isCodexCLI, + reqStream, + originalModel, + mappedModel, + startTime, + attempt, + wsLastFailureReason, + ) + if wsErr == nil { + break + } + if c != nil && c.Writer != nil && c.Writer.Written() { + break + } + + reason, retryable := classifyOpenAIWSReconnectReason(wsErr) + if reason != "" { + wsLastFailureReason = reason + } + // previous_response_not_found 说明续链锚点不可用: + // 对非 function_call_output 场景,允许一次“去掉 previous_response_id 后重放”。 + if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) { + continue + } + if retryable && attempt < maxAttempts { + backoff := s.openAIWSRetryBackoff(attempt) + if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_budget_exhausted account_id=%d attempts=%d max_retries=%d reason=%s elapsed_ms=%d budget_ms=%d", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + time.Since(retryStartedAt).Milliseconds(), + retryBudget.Milliseconds(), + ) + break + } + s.recordOpenAIWSRetryAttempt(backoff) + logOpenAIWSModeInfo( + "reconnect_retry account_id=%d retry=%d max_retries=%d reason=%s backoff_ms=%d", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + backoff.Milliseconds(), + ) + if backoff > 0 { + timer := time.NewTimer(backoff) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + wsErr = wrapOpenAIWSFallback("retry_backoff_canceled", ctx.Err()) + break wsRetryLoop + case <-timer.C: + } + } + continue + } + if retryable { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_exhausted account_id=%d attempts=%d max_retries=%d reason=%s", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + ) + } else if reason != "" { + s.recordOpenAIWSNonRetryableFastFallback() + logOpenAIWSModeInfo( + "reconnect_stop account_id=%d attempt=%d reason=%s", + account.ID, + attempt, + normalizeOpenAIWSLogValue(reason), + ) + } + break + } + if wsErr == nil { + firstTokenMs := int64(0) + hasFirstTokenMs := wsResult != nil && wsResult.FirstTokenMs != nil + if hasFirstTokenMs { + firstTokenMs = int64(*wsResult.FirstTokenMs) + } + requestID := "" + if wsResult != nil { + requestID = strings.TrimSpace(wsResult.RequestID) + } + logOpenAIWSModeDebug( + "forward_succeeded account_id=%d request_id=%s stream=%v has_first_token_ms=%v first_token_ms=%d ws_attempts=%d", + account.ID, + requestID, + reqStream, + hasFirstTokenMs, + firstTokenMs, + wsAttempts, + ) + return wsResult, nil + } + s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) + return nil, wsErr + } + // Build upstream request upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) if err != nil { @@ -882,13 +1889,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco proxyURL = account.Proxy.URL() } - // Capture upstream request body for ops retry of this attempt. - if c != nil { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - // Send request + upstreamStart := time.Now() resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) if err != nil { // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). safeErr := sanitizeUpstreamErrorMessage(err.Error()) @@ -942,7 +1946,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco s.handleFailoverSideEffects(ctx, resp, account) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - return s.handleErrorResponse(ctx, resp, c, account) + return s.handleErrorResponse(ctx, resp, c, account, body) } // Handle normal response @@ -969,6 +1973,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } + if usage == nil { + usage = &OpenAIUsage{} + } + reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) return &OpenAIForwardResult{ @@ -977,11 +1985,576 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco Model: originalModel, ReasoningEffort: reasoningEffort, Stream: reqStream, + OpenAIWSMode: false, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil } +func (s *OpenAIGatewayService) forwardOpenAIPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + reqModel string, + reasoningEffort *string, + reqStream bool, + startTime time.Time, +) (*OpenAIForwardResult, error) { + if account != nil && account.Type == AccountTypeOAuth { + if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" { + rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field" + setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: http.StatusForbidden, + Passthrough: true, + Kind: "request_error", + Message: rejectMsg, + Detail: rejectReason, + }) + logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body) + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "forbidden_error", + "message": rejectMsg, + }, + }) + return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason) + } + + normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body) + if err != nil { + return nil, err + } + if normalized { + body = normalizedBody + reqStream = true + } + } + + logger.LegacyPrintf("service.openai_gateway", + "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", + account.ID, + account.Name, + account.Type, + reqModel, + reqStream, + ) + if reqStream && c != nil && c.Request != nil { + if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 { + streamWarnLogger := logger.FromContext(ctx).With( + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", account.ID), + zap.Strings("timeout_headers", timeoutHeaders), + ) + if s.isOpenAIPassthroughTimeoutHeadersAllowed() { + streamWarnLogger.Warn("OpenAI passthrough 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流") + } else { + streamWarnLogger.Warn("OpenAI passthrough 检测到超时相关请求头,将按配置过滤以降低断流风险") + } + } + } + + // Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token) + if err != nil { + return nil, err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + setOpsUpstreamRequestBody(c, body) + if c != nil { + c.Set("openai_passthrough", true) + } + + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + // 透传模式不做 failover(避免改变原始上游语义),按上游原样返回错误响应。 + return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body) + } + + var usage *OpenAIUsage + var firstTokenMs *int + if reqStream { + result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime) + if err != nil { + return nil, err + } + usage = result.usage + firstTokenMs = result.firstTokenMs + } else { + usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c) + if err != nil { + return nil, err + } + } + + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + + if usage == nil { + usage = &OpenAIUsage{} + } + + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: reqModel, + ReasoningEffort: reasoningEffort, + Stream: reqStream, + OpenAIWSMode: false, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func logOpenAIPassthroughInstructionsRejected( + ctx context.Context, + c *gin.Context, + account *Account, + reqModel string, + rejectReason string, + body []byte, +) { + if ctx == nil { + ctx = context.Background() + } + accountID := int64(0) + accountName := "" + accountType := "" + if account != nil { + accountID = account.ID + accountName = strings.TrimSpace(account.Name) + accountType = strings.TrimSpace(string(account.Type)) + } + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.String("account_name", accountName), + zap.String("account_type", accountType), + zap.String("request_model", strings.TrimSpace(reqModel)), + zap.String("reject_reason", strings.TrimSpace(rejectReason)), + } + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body) + logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions") +} + +func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := openaiPlatformAPIURL + switch account.Type { + case AccountTypeOAuth: + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + baseURL := account.GetOpenAIBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + // 透传客户端请求头(安全白名单)。 + allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed() + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lower := strings.ToLower(strings.TrimSpace(key)) + if !isOpenAIPassthroughAllowedRequestHeader(lower, allowTimeoutHeaders) { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // 覆盖入站鉴权残留,并注入上游认证 + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Set("authorization", "Bearer "+token) + + // OAuth 透传到 ChatGPT internal API 时补齐必要头。 + if account.Type == AccountTypeOAuth { + promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + req.Host = "chatgpt.com" + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } + if req.Header.Get("accept") == "" { + req.Header.Set("accept", "text/event-stream") + } + if req.Header.Get("OpenAI-Beta") == "" { + req.Header.Set("OpenAI-Beta", "responses=experimental") + } + if req.Header.Get("originator") == "" { + req.Header.Set("originator", "codex_cli_rs") + } + if promptCacheKey != "" { + if req.Header.Get("conversation_id") == "" { + req.Header.Set("conversation_id", promptCacheKey) + } + if req.Header.Get("session_id") == "" { + req.Header.Set("session_id", promptCacheKey) + } + } + } + + // 透传模式也支持账户自定义 User-Agent 与 ForceCodexCLI 兜底。 + customUA := account.GetOpenAIUserAgent() + if customUA != "" { + req.Header.Set("user-agent", customUA) + } + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + req.Header.Set("user-agent", codexCLIUserAgent) + } + // OAuth 安全透传:对非 Codex UA 统一兜底,降低被上游风控拦截概率。 + if account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(req.Header.Get("user-agent")) { + req.Header.Set("user-agent", codexCLIUserAgent) + } + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + + return req, nil +} + +func (s *OpenAIGatewayService) handleErrorResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + requestBody []byte, +) error { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + UpstreamResponseBody: upstreamDetail, + }) + + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) +} + +func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool { + if lowerKey == "" { + return false + } + if isOpenAIPassthroughTimeoutHeader(lowerKey) { + return allowTimeoutHeaders + } + return openaiPassthroughAllowedHeaders[lowerKey] +} + +func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool { + switch lowerKey { + case "x-stainless-timeout", "x-stainless-read-timeout", "x-stainless-connect-timeout", "x-request-timeout", "request-timeout", "grpc-timeout": + return true + default: + return false + } +} + +func (s *OpenAIGatewayService) isOpenAIPassthroughTimeoutHeadersAllowed() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIPassthroughAllowTimeoutHeaders +} + +func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string { + if h == nil { + return nil + } + var matched []string + for key, values := range h { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if isOpenAIPassthroughTimeoutHeader(lowerKey) { + entry := lowerKey + if len(values) > 0 { + entry = fmt.Sprintf("%s=%s", lowerKey, strings.Join(values, "|")) + } + matched = append(matched, entry) + } + } + sort.Strings(matched) + return matched +} + +type openaiStreamingResultPassthrough struct { + usage *OpenAIUsage + firstTokenMs *int +} + +func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + startTime time.Time, +) (*openaiStreamingResultPassthrough, error) { + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + // SSE headers + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + usage := &OpenAIUsage{} + var firstTokenMs *int + clientDisconnected := false + sawDone := false + upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) + defer putSSEScannerBuf64K(scanBuf) + + for scanner.Scan() { + line := scanner.Text() + if data, ok := extractOpenAISSEDataLine(line); ok { + dataBytes := []byte(data) + trimmedData := strings.TrimSpace(data) + if trimmedData == "[DONE]" { + sawDone = true + } + if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsageBytes(dataBytes, usage) + } + + if !clientDisconnected { + if _, err := fmt.Fprintln(w, line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else { + flusher.Flush() + } + } + } + if err := scanner.Err(); err != nil { + if clientDisconnected { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + logger.LegacyPrintf("service.openai_gateway", + "[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v", + account.ID, + upstreamRequestID, + err, + ctx.Err(), + ) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if errors.Is(err, bufio.ErrTooLong) { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err + } + logger.LegacyPrintf("service.openai_gateway", + "[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v", + account.ID, + upstreamRequestID, + err, + ) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) + } + if !clientDisconnected && !sawDone && ctx.Err() == nil { + logger.FromContext(ctx).With( + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", account.ID), + zap.String("upstream_request_id", upstreamRequestID), + ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") + } + + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil +} + +func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, +) (*OpenAIUsage, error) { + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + usage := &OpenAIUsage{} + usageParsed := false + if len(body) > 0 { + if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(body); ok { + *usage = parsedUsage + usageParsed = true + } + } + if !usageParsed { + // 兜底:尝试从 SSE 文本中解析 usage + usage = s.parseSSEUsageFromBody(string(body)) + } + + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + return usage, nil +} + +func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { + if dst == nil || src == nil { + return + } + if filter != nil { + responseheaders.WriteFilteredHeaders(dst, src, filter) + } else { + // 兜底:尽量保留最基础的 content-type + if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { + dst.Set("Content-Type", v) + } + } + // 透传模式强制放行 x-codex-* 响应头(若上游返回)。 + // 注意:真实 http.Response.Header 的 key 一般会被 canonicalize;但为了兼容测试/自建响应, + // 这里用 EqualFold 做一次大小写不敏感的查找。 + getCaseInsensitiveValues := func(h http.Header, want string) []string { + if h == nil { + return nil + } + for k, vals := range h { + if strings.EqualFold(k, want) { + return vals + } + } + return nil + } + + for _, rawKey := range []string{ + "x-codex-primary-used-percent", + "x-codex-primary-reset-after-seconds", + "x-codex-primary-window-minutes", + "x-codex-secondary-used-percent", + "x-codex-secondary-reset-after-seconds", + "x-codex-secondary-window-minutes", + "x-codex-primary-over-secondary-limit-percent", + } { + vals := getCaseInsensitiveValues(src, rawKey) + if len(vals) == 0 { + continue + } + key := http.CanonicalHeaderKey(rawKey) + dst.Del(key) + for _, v := range vals { + dst.Add(key, v) + } + } +} + func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) { // Determine target URL based on account type var targetURL string @@ -999,7 +2572,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. if err != nil { return nil, err } - targetURL = validatedURL + "/responses" + targetURL = buildOpenAIResponsesURL(validatedURL) } default: targetURL = openaiPlatformAPIURL @@ -1053,6 +2626,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. req.Header.Set("user-agent", customUA) } + // 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。 + // 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。 + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + req.Header.Set("user-agent", codexCLIUserAgent) + } + // Ensure required headers exist if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") @@ -1061,7 +2640,13 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. return req, nil } -func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) { +func (s *OpenAIGatewayService) handleErrorResponse( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + requestBody []byte, +) (*OpenAIForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) @@ -1075,9 +2660,10 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht upstreamDetail = truncateString(string(body), maxBytes) } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.openai_gateway", "OpenAI upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -1205,8 +2791,8 @@ type openaiStreamingResult struct { } func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } // Set SSE response headers @@ -1225,6 +2811,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if !ok { return nil, errors.New("streaming not supported") } + bufferedWriter := bufio.NewWriterSize(w, 4*1024) + flushBuffered := func() error { + if err := bufferedWriter.Flush(); err != nil { + return err + } + flusher.Flush() + return nil + } usage := &OpenAIUsage{} var firstTokenMs *int @@ -1233,38 +2827,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) - - type scanEvent struct { - line string - err error - } - // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 - events := make(chan scanEvent, 16) - done := make(chan struct{}) - sendEvent := func(ev scanEvent) bool { - select { - case events <- ev: - return true - case <-done: - return false - } - } - var lastReadAt int64 - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { - defer close(events) - for scanner.Scan() { - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - if !sendEvent(scanEvent{line: scanner.Text()}) { - return - } - } - if err := scanner.Err(); err != nil { - _ = sendEvent(scanEvent{err: err}) - } - }() - defer close(done) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) streamInterval := time.Duration(0) if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { @@ -1308,95 +2872,178 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return } errorEventSent = true - payload := map[string]any{ - "type": "error", - "sequence_number": 0, - "error": map[string]any{ - "type": "upstream_error", - "message": reason, - "code": reason, - }, + payload := `{"type":"error","sequence_number":0,"error":{"type":"upstream_error","message":` + strconv.Quote(reason) + `,"code":` + strconv.Quote(reason) + `}}` + if err := flushBuffered(); err != nil { + clientDisconnected = true + return } - if b, err := json.Marshal(payload); err == nil { - _, _ = fmt.Fprintf(w, "data: %s\n\n", b) - flusher.Flush() + if _, err := bufferedWriter.WriteString("data: " + payload + "\n\n"); err != nil { + clientDisconnected = true + return + } + if err := flushBuffered(); err != nil { + clientDisconnected = true } } needModelReplace := originalModel != mappedModel + resultWithUsage := func() *openaiStreamingResult { + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs} + } + finalizeStream := func() (*openaiStreamingResult, error) { + if !clientDisconnected { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage") + } + } + return resultWithUsage(), nil + } + handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { + if scanErr == nil { + return nil, nil, false + } + // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 + // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 + if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { + logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") + return resultWithUsage(), nil, true + } + // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage + if clientDisconnected { + logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr) + return resultWithUsage(), nil, true + } + if errors.Is(scanErr, bufio.ErrTooLong) { + logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) + sendErrorEvent("response_too_large") + return resultWithUsage(), scanErr, true + } + sendErrorEvent("stream_read_error") + return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true + } + processSSELine := func(line string, queueDrained bool) { + lastDataAt = time.Now() + + // Extract data from SSE line (supports both "data: " and "data:" formats) + if data, ok := extractOpenAISSEDataLine(line); ok { + + // Replace model in response if needed. + // Fast path: most events do not contain model field values. + if needModelReplace && mappedModel != "" && strings.Contains(data, mappedModel) { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } + + dataBytes := []byte(data) + + // Correct Codex tool calls if needed (apply_patch -> edit, etc.) + if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { + dataBytes = correctedData + data = string(correctedData) + line = "data: " + data + } + + // 写入客户端(客户端断开后继续 drain 上游) + if !clientDisconnected { + shouldFlush := queueDrained + if firstTokenMs == nil && data != "" && data != "[DONE]" { + // 保证首个 token 事件尽快出站,避免影响 TTFT。 + shouldFlush = true + } + if _, err := bufferedWriter.WriteString(line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if _, err := bufferedWriter.WriteString("\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if shouldFlush { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") + } + } + } + + // Record first token time + if firstTokenMs == nil && data != "" && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsageBytes(dataBytes, usage) + return + } + + // Forward non-data lines as-is + if !clientDisconnected { + if _, err := bufferedWriter.WriteString(line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if _, err := bufferedWriter.WriteString("\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if queueDrained { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") + } + } + } + } + + // 无超时/无 keepalive 的常见路径走同步扫描,减少 goroutine 与 channel 开销。 + if streamInterval <= 0 && keepaliveInterval <= 0 { + defer putSSEScannerBuf64K(scanBuf) + for scanner.Scan() { + processSSELine(scanner.Text(), true) + } + if result, err, done := handleScanErr(scanner.Err()); done { + return result, err + } + return finalizeStream() + } + + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) for { select { case ev, ok := <-events: if !ok { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + return finalizeStream() } - if ev.err != nil { - // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 - // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 - if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - log.Printf("Context canceled during streaming, returning collected usage") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil - } - // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage - if clientDisconnected { - log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err) - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil - } - if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) - sendErrorEvent("response_too_large") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err - } - sendErrorEvent("stream_read_error") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) - } - - line := ev.line - lastDataAt = time.Now() - - // Extract data from SSE line (supports both "data: " and "data:" formats) - if openaiSSEDataRe.MatchString(line) { - data := openaiSSEDataRe.ReplaceAllString(line, "") - - // Replace model in response if needed - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) - } - - // Correct Codex tool calls if needed (apply_patch -> edit, etc.) - if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { - data = correctedData - line = "data: " + correctedData - } - - // 写入客户端(客户端断开后继续 drain 上游) - if !clientDisconnected { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") - } else { - flusher.Flush() - } - } - - // Record first token time - if firstTokenMs == nil && data != "" && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - s.parseSSEUsage(data, usage) - } else { - // Forward non-data lines as-is - if !clientDisconnected { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") - } else { - flusher.Flush() - } - } + if result, err, done := handleScanErr(ev.err); done { + return result, err } + processSSELine(ev.line, len(events) == 0) case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) @@ -1404,16 +3051,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp continue } if clientDisconnected { - log.Printf("Upstream timeout after client disconnect, returning collected usage") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage") + return resultWithUsage(), nil } - log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 if s.rateLimitService != nil { s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) } sendErrorEvent("stream_timeout") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + return resultWithUsage(), fmt.Errorf("stream data interval timeout") case <-keepaliveCh: if clientDisconnected { @@ -1422,51 +3069,61 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if time.Since(lastDataAt) < keepaliveInterval { continue } - if _, err := fmt.Fprint(w, ":\n\n"); err != nil { + if _, err := bufferedWriter.WriteString(":\n\n"); err != nil { clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") continue } - flusher.Flush() + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing") + } } } } +// extractOpenAISSEDataLine 低开销提取 SSE `data:` 行内容。 +// 兼容 `data: xxx` 与 `data:xxx` 两种格式。 +func extractOpenAISSEDataLine(line string) (string, bool) { + if !strings.HasPrefix(line, "data:") { + return "", false + } + start := len("data:") + for start < len(line) { + if line[start] != ' ' && line[start] != ' ' { + break + } + start++ + } + return line[start:], true +} + func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { return line } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { return line } - var event map[string]any - if err := json.Unmarshal([]byte(data), &event); err != nil { - return line - } - - // Replace model in response - if m, ok := event["model"].(string); ok && m == fromModel { - event["model"] = toModel - newData, err := json.Marshal(event) + // 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化 + if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel { + newData, err := sjson.Set(data, "model", toModel) if err != nil { return line } - return "data: " + string(newData) + return "data: " + newData } - // Check nested response - if response, ok := event["response"].(map[string]any); ok { - if m, ok := response["model"].(string); ok && m == fromModel { - response["model"] = toModel - newData, err := json.Marshal(event) - if err != nil { - return line - } - return "data: " + string(newData) + // 检查嵌套的 response.model 字段 + if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel { + newData, err := sjson.Set(data, "response.model", toModel) + if err != nil { + return line } + return "data: " + newData } return line @@ -1478,39 +3135,64 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt return body } - bodyStr := string(body) - corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr) + corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(body) if changed { - return []byte(corrected) + return corrected } return body } func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { - // Parse response.completed event for usage (OpenAI Responses format) - var event struct { - Type string `json:"type"` - Response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } `json:"response"` + s.parseSSEUsageBytes([]byte(data), usage) +} + +func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) { + if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { + return + } + // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 + if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) { + return + } + if gjson.GetBytes(data, "type").String() != "response.completed" { + return } - if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" { - usage.InputTokens = event.Response.Usage.InputTokens - usage.OutputTokens = event.Response.Usage.OutputTokens - usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens + usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int()) + usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int()) + usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int()) +} + +func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { + if len(body) == 0 || !gjson.ValidBytes(body) { + return OpenAIUsage{}, false } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + "usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + ) + return OpenAIUsage{ + InputTokens: int(values[0].Int()), + OutputTokens: int(values[1].Int()), + CacheReadInputTokens: int(values[2].Int()), + }, true } func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -1521,32 +3203,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r } } - // Parse usage - var response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } - if err := json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("parse response: %w", err) - } - - usage := &OpenAIUsage{ - InputTokens: response.Usage.InputTokens, - OutputTokens: response.Usage.OutputTokens, - CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens, + usageValue, usageOK := extractOpenAIUsageFromJSONBytes(body) + if !usageOK { + return nil, fmt.Errorf("parse response: invalid json response") } + usage := &usageValue // Replace model in response if needed if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json" if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { @@ -1571,19 +3239,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. usage := &OpenAIUsage{} if ok { - var response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } - if err := json.Unmarshal(finalResponse, &response); err == nil { - usage.InputTokens = response.Usage.InputTokens - usage.OutputTokens = response.Usage.OutputTokens - usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens + if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed { + *usage = parsedUsage } body = finalResponse if originalModel != mappedModel { @@ -1599,7 +3256,7 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. body = []byte(bodyText) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json; charset=utf-8" if !ok { @@ -1616,23 +3273,17 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. func extractCodexFinalResponse(body string) ([]byte, bool) { lines := strings.Split(body, "\n") for _, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { continue } - var event struct { - Type string `json:"type"` - Response json.RawMessage `json:"response"` - } - if json.Unmarshal([]byte(data), &event) != nil { - continue - } - if event.Type == "response.done" || event.Type == "response.completed" { - if len(event.Response) > 0 { - return event.Response, true + eventType := gjson.Get(data, "type").String() + if eventType == "response.done" || eventType == "response.completed" { + if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" { + return []byte(response.Raw), true } } } @@ -1643,14 +3294,14 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { usage := &OpenAIUsage{} lines := strings.Split(body, "\n") for _, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { continue } - s.parseSSEUsage(data, usage) + s.parseSSEUsageBytes([]byte(data), usage) } return usage } @@ -1658,7 +3309,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string { lines := strings.Split(body, "\n") for i, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + if _, ok := extractOpenAISSEDataLine(line); !ok { continue } lines[i] = s.replaceModelInSSELine(line, fromModel, toModel) @@ -1685,24 +3336,31 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro return normalized, nil } +// buildOpenAIResponsesURL 组装 OpenAI Responses 端点。 +// - base 以 /v1 结尾:追加 /responses +// - base 已是 /responses:原样返回 +// - 其他情况:追加 /v1/responses +func buildOpenAIResponsesURL(base string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + if strings.HasSuffix(normalized, "/responses") { + return normalized + } + if strings.HasSuffix(normalized, "/v1") { + return normalized + "/responses" + } + return normalized + "/v1/responses" +} + func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return body + // 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化 + if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { + newBody, err := sjson.SetBytes(body, "model", toModel) + if err != nil { + return body + } + return newBody } - - model, ok := resp["model"].(string) - if !ok || model != fromModel { - return body - } - - resp["model"] = toModel - newBody, err := json.Marshal(resp) - if err != nil { - return body - } - - return newBody + return body } // OpenAIRecordUsageInput input for recording usage @@ -1782,6 +3440,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, Stream: result.Stream, + OpenAIWSMode: result.OpenAIWSMode, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, CreatedAt: time.Now(), @@ -1806,7 +3465,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec inserted, err := s.usageLogRepo.Create(ctx, usageLog) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -1829,7 +3488,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Update API key quota if applicable (only for balance mode with quota set) if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - log.Printf("Update API key quota failed: %v", err) + logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err) } } @@ -1907,16 +3566,41 @@ func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { return snapshot } -// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field -func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { +func codexSnapshotBaseTime(snapshot *OpenAICodexUsageSnapshot, fallback time.Time) time.Time { if snapshot == nil { - return + return fallback + } + if snapshot.UpdatedAt == "" { + return fallback + } + base, err := time.Parse(time.RFC3339, snapshot.UpdatedAt) + if err != nil { + return fallback + } + return base +} + +func codexResetAtRFC3339(base time.Time, resetAfterSeconds *int) *string { + if resetAfterSeconds == nil { + return nil + } + sec := *resetAfterSeconds + if sec < 0 { + sec = 0 + } + resetAt := base.Add(time.Duration(sec) * time.Second).Format(time.RFC3339) + return &resetAt +} + +func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) map[string]any { + if snapshot == nil { + return nil } - // Convert snapshot to map for merging into Extra + baseTime := codexSnapshotBaseTime(snapshot, fallbackNow) updates := make(map[string]any) - // Save raw primary/secondary fields for debugging/tracing + // 保存原始 primary/secondary 字段,便于排查问题 if snapshot.PrimaryUsedPercent != nil { updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent } @@ -1938,9 +3622,9 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc if snapshot.PrimaryOverSecondaryPercent != nil { updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent } - updates["codex_usage_updated_at"] = snapshot.UpdatedAt + updates["codex_usage_updated_at"] = baseTime.Format(time.RFC3339) - // Normalize to canonical 5h/7d fields + // 归一化到 5h/7d 规范字段 if normalized := snapshot.Normalize(); normalized != nil { if normalized.Used5hPercent != nil { updates["codex_5h_used_percent"] = *normalized.Used5hPercent @@ -1960,6 +3644,29 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc if normalized.Window7dMinutes != nil { updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes } + if reset5hAt := codexResetAtRFC3339(baseTime, normalized.Reset5hSeconds); reset5hAt != nil { + updates["codex_5h_reset_at"] = *reset5hAt + } + if reset7dAt := codexResetAtRFC3339(baseTime, normalized.Reset7dSeconds); reset7dAt != nil { + updates["codex_7d_reset_at"] = *reset7dAt + } + } + + return updates +} + +// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field +func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { + if snapshot == nil { + return + } + if s == nil || s.accountRepo == nil { + return + } + + updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) + if len(updates) == 0 { + return } // Update account's Extra field asynchronously @@ -2016,6 +3723,106 @@ func deriveOpenAIReasoningEffortFromModel(model string) string { return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) } +func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) { + if len(body) == 0 { + return "", false, "" + } + + model = strings.TrimSpace(gjson.GetBytes(body, "model").String()) + stream = gjson.GetBytes(body, "stream").Bool() + promptCacheKey = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + return model, stream, promptCacheKey +} + +// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为: +// 1) store=false 2) stream=true +func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) { + if len(body) == 0 { + return body, false, nil + } + + normalized := body + changed := false + + if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False { + next, err := sjson.SetBytes(normalized, "store", false) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err) + } + normalized = next + changed = true + } + + if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True { + next, err := sjson.SetBytes(normalized, "stream", true) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err) + } + normalized = next + changed = true + } + + return normalized, changed, nil +} + +func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string { + model := strings.ToLower(strings.TrimSpace(reqModel)) + if !strings.Contains(model, "codex") { + return "" + } + + instructions := gjson.GetBytes(body, "instructions") + if !instructions.Exists() { + return "instructions_missing" + } + if instructions.Type != gjson.String { + return "instructions_not_string" + } + if strings.TrimSpace(instructions.String()) == "" { + return "instructions_empty" + } + return "" +} + +func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { + reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if reasoningEffort == "" { + reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String()) + } + if reasoningEffort != "" { + normalized := normalizeOpenAIReasoningEffort(reasoningEffort) + if normalized == "" { + return nil + } + return &normalized + } + + value := deriveOpenAIReasoningEffortFromModel(requestedModel) + if value == "" { + return nil + } + return &value +} + +func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) { + if c != nil { + if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok { + if reqBody, ok := cached.(map[string]any); ok && reqBody != nil { + return reqBody, nil + } + } + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + if c != nil { + c.Set(OpenAIParsedRequestBodyKey, reqBody) + } + return reqBody, nil +} + func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string { if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present { if value == "" { diff --git a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go new file mode 100644 index 00000000..d7c95ada --- /dev/null +++ b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go @@ -0,0 +1,266 @@ +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type stubCodexRestrictionDetector struct { + result CodexClientRestrictionDetectionResult +} + +func (s *stubCodexRestrictionDetector) Detect(_ *gin.Context, _ *Account) CodexClientRestrictionDetectionResult { + return s.result +} + +func TestOpenAIGatewayService_GetCodexClientRestrictionDetector(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("使用注入的 detector", func(t *testing.T) { + expected := &stubCodexRestrictionDetector{ + result: CodexClientRestrictionDetectionResult{Enabled: true, Matched: true, Reason: "stub"}, + } + svc := &OpenAIGatewayService{codexDetector: expected} + + got := svc.getCodexClientRestrictionDetector() + require.Same(t, expected, got) + }) + + t.Run("service 为 nil 时返回默认 detector", func(t *testing.T) { + var svc *OpenAIGatewayService + got := svc.getCodexClientRestrictionDetector() + require.NotNil(t, got) + }) + + t.Run("service 未注入 detector 时返回默认 detector", func(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: true}}} + got := svc.getCodexClientRestrictionDetector() + require.NotNil(t, got) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.Request.Header.Set("User-Agent", "curl/8.0") + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{"codex_cli_only": true}} + + result := got.Detect(c, account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason) + }) +} + +func TestGetAPIKeyIDFromContext(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("context 为 nil", func(t *testing.T) { + require.Equal(t, int64(0), getAPIKeyIDFromContext(nil)) + }) + + t.Run("上下文没有 api_key", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + require.Equal(t, int64(0), getAPIKeyIDFromContext(c)) + }) + + t.Run("api_key 类型错误", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Set("api_key", "not-api-key") + require.Equal(t, int64(0), getAPIKeyIDFromContext(c)) + }) + + t.Run("api_key 指针为空", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + var k *APIKey + c.Set("api_key", k) + require.Equal(t, int64(0), getAPIKeyIDFromContext(c)) + }) + + t.Run("正常读取 api_key_id", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Set("api_key", &APIKey{ID: 12345}) + require.Equal(t, int64(12345), getAPIKeyIDFromContext(c)) + }) +} + +func TestLogCodexCLIOnlyDetection_NilSafety(t *testing.T) { + // 不校验日志内容,仅保证在 nil 入参下不会 panic。 + require.NotPanics(t, func() { + logCodexCLIOnlyDetection(context.TODO(), nil, nil, 0, CodexClientRestrictionDetectionResult{Enabled: true, Matched: false, Reason: "test"}, nil) + logCodexCLIOnlyDetection(context.Background(), nil, nil, 0, CodexClientRestrictionDetectionResult{Enabled: false, Matched: false, Reason: "disabled"}, nil) + }) +} + +func TestLogCodexCLIOnlyDetection_OnlyLogsRejected(t *testing.T) { + logSink, restore := captureStructuredLog(t) + defer restore() + + account := &Account{ID: 1001} + logCodexCLIOnlyDetection(context.Background(), nil, account, 2002, CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonMatchedUA, + }, nil) + logCodexCLIOnlyDetection(context.Background(), nil, account, 2002, CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: false, + Reason: CodexClientRestrictionReasonNotMatchedUA, + }, nil) + + require.False(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求")) + require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 拒绝非官方客户端请求")) +} + +func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "assistants=v2") + + body := []byte(`{"model":"gpt-5.2","stream":false,"prompt_cache_key":"pc-123","access_token":"secret-token","input":[{"type":"text","text":"hello"}]}`) + account := &Account{ID: 1001} + logCodexCLIOnlyDetection(context.Background(), c, account, 2002, CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: false, + Reason: CodexClientRestrictionReasonNotMatchedUA, + }, body) + + require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")) + require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2")) + require.True(t, logSink.ContainsFieldValue("request_query", "trace=1")) + require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123"))) + require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta")) + require.True(t, logSink.ContainsField("request_body_size")) + require.False(t, logSink.ContainsField("request_body_preview")) +} + +func TestLogOpenAIInstructionsRequiredDebug_LogsRequestDetails(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "assistants=v2") + + body := []byte(`{"model":"gpt-5.1-codex","stream":false,"prompt_cache_key":"pc-abc","access_token":"secret-token","input":[{"type":"text","text":"hello"}]}`) + account := &Account{ID: 1001, Name: "codex max套餐"} + + logOpenAIInstructionsRequiredDebug( + context.Background(), + c, + account, + http.StatusBadRequest, + "Instructions are required", + body, + []byte(`{"error":{"message":"Instructions are required","type":"invalid_request_error","param":"instructions","code":"missing_required_parameter"}}`), + ) + + require.True(t, logSink.ContainsMessageAtLevel("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查", "warn")) + require.True(t, logSink.ContainsFieldValue("request_user_agent", "curl/8.0")) + require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.1-codex")) + require.True(t, logSink.ContainsFieldValue("request_query", "trace=1")) + require.True(t, logSink.ContainsFieldValue("account_name", "codex max套餐")) + require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta")) + require.True(t, logSink.ContainsField("request_body_size")) + require.False(t, logSink.ContainsField("request_body_preview")) +} + +func TestLogOpenAIInstructionsRequiredDebug_NonTargetErrorSkipped(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + body := []byte(`{"model":"gpt-5.1-codex","stream":false}`) + + logOpenAIInstructionsRequiredDebug( + context.Background(), + c, + &Account{ID: 1001}, + http.StatusForbidden, + "forbidden", + body, + []byte(`{"error":{"message":"forbidden"}}`), + ) + + require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查")) +} + +func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "assistants=v2") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-upstream"}, + }, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Missing required parameter: 'instructions'","type":"invalid_request_error","param":"instructions","code":"missing_required_parameter"}}`)), + }, + } + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ForceCodexCLI: false}, + }, + httpUpstream: upstream, + } + account := &Account{ + ID: 1001, + Name: "codex max套餐", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-test"}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + body := []byte(`{"model":"gpt-5.1-codex","stream":false,"input":[{"type":"text","text":"hello"}],"prompt_cache_key":"pc-forward","access_token":"secret-token"}`) + + _, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Equal(t, http.StatusBadGateway, rec.Code) + require.Contains(t, err.Error(), "upstream error: 400") + + require.True(t, logSink.ContainsMessageAtLevel("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查", "warn")) + require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.1.0")) + require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.1-codex")) + require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta")) + require.True(t, logSink.ContainsField("request_body_size")) + require.False(t, logSink.ContainsField("request_body_preview")) +} diff --git a/backend/internal/service/openai_gateway_service_codex_snapshot_test.go b/backend/internal/service/openai_gateway_service_codex_snapshot_test.go new file mode 100644 index 00000000..654dd4ca --- /dev/null +++ b/backend/internal/service/openai_gateway_service_codex_snapshot_test.go @@ -0,0 +1,192 @@ +package service + +import ( + "testing" + "time" +) + +func TestCodexSnapshotBaseTime(t *testing.T) { + fallback := time.Date(2026, 2, 20, 9, 0, 0, 0, time.UTC) + + t.Run("nil snapshot uses fallback", func(t *testing.T) { + got := codexSnapshotBaseTime(nil, fallback) + if !got.Equal(fallback) { + t.Fatalf("got %v, want fallback %v", got, fallback) + } + }) + + t.Run("empty updatedAt uses fallback", func(t *testing.T) { + got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{}, fallback) + if !got.Equal(fallback) { + t.Fatalf("got %v, want fallback %v", got, fallback) + } + }) + + t.Run("valid updatedAt wins", func(t *testing.T) { + got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{UpdatedAt: "2026-02-16T10:00:00Z"}, fallback) + want := time.Date(2026, 2, 16, 10, 0, 0, 0, time.UTC) + if !got.Equal(want) { + t.Fatalf("got %v, want %v", got, want) + } + }) + + t.Run("invalid updatedAt uses fallback", func(t *testing.T) { + got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{UpdatedAt: "invalid"}, fallback) + if !got.Equal(fallback) { + t.Fatalf("got %v, want fallback %v", got, fallback) + } + }) +} + +func TestCodexResetAtRFC3339(t *testing.T) { + base := time.Date(2026, 2, 16, 10, 0, 0, 0, time.UTC) + + t.Run("nil reset returns nil", func(t *testing.T) { + if got := codexResetAtRFC3339(base, nil); got != nil { + t.Fatalf("expected nil, got %v", *got) + } + }) + + t.Run("positive seconds", func(t *testing.T) { + sec := 90 + got := codexResetAtRFC3339(base, &sec) + if got == nil { + t.Fatal("expected non-nil") + } + if *got != "2026-02-16T10:01:30Z" { + t.Fatalf("got %s, want %s", *got, "2026-02-16T10:01:30Z") + } + }) + + t.Run("negative seconds clamp to base", func(t *testing.T) { + sec := -3 + got := codexResetAtRFC3339(base, &sec) + if got == nil { + t.Fatal("expected non-nil") + } + if *got != "2026-02-16T10:00:00Z" { + t.Fatalf("got %s, want %s", *got, "2026-02-16T10:00:00Z") + } + }) +} + +func TestBuildCodexUsageExtraUpdates_UsesSnapshotUpdatedAt(t *testing.T) { + primaryUsed := 88.0 + primaryReset := 86400 + primaryWindow := 10080 + secondaryUsed := 12.0 + secondaryReset := 3600 + secondaryWindow := 300 + + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + PrimaryResetAfterSeconds: &primaryReset, + PrimaryWindowMinutes: &primaryWindow, + SecondaryUsedPercent: &secondaryUsed, + SecondaryResetAfterSeconds: &secondaryReset, + SecondaryWindowMinutes: &secondaryWindow, + UpdatedAt: "2026-02-16T10:00:00Z", + } + + updates := buildCodexUsageExtraUpdates(snapshot, time.Date(2026, 2, 20, 8, 0, 0, 0, time.UTC)) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_usage_updated_at"]; got != "2026-02-16T10:00:00Z" { + t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-16T10:00:00Z") + } + if got := updates["codex_5h_reset_at"]; got != "2026-02-16T11:00:00Z" { + t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-16T11:00:00Z") + } + if got := updates["codex_7d_reset_at"]; got != "2026-02-17T10:00:00Z" { + t.Fatalf("codex_7d_reset_at = %v, want %s", got, "2026-02-17T10:00:00Z") + } +} + +func TestBuildCodexUsageExtraUpdates_FallbackToNowWhenUpdatedAtInvalid(t *testing.T) { + primaryUsed := 15.0 + primaryReset := 30 + primaryWindow := 300 + + fallbackNow := time.Date(2026, 2, 20, 8, 30, 0, 0, time.UTC) + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + PrimaryResetAfterSeconds: &primaryReset, + PrimaryWindowMinutes: &primaryWindow, + UpdatedAt: "invalid-time", + } + + updates := buildCodexUsageExtraUpdates(snapshot, fallbackNow) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_usage_updated_at"]; got != "2026-02-20T08:30:00Z" { + t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-20T08:30:00Z") + } + if got := updates["codex_5h_reset_at"]; got != "2026-02-20T08:30:30Z" { + t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-20T08:30:30Z") + } +} + +func TestBuildCodexUsageExtraUpdates_ClampNegativeResetSeconds(t *testing.T) { + primaryUsed := 90.0 + primaryReset := 7200 + primaryWindow := 10080 + secondaryUsed := 100.0 + secondaryReset := -15 + secondaryWindow := 300 + + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + PrimaryResetAfterSeconds: &primaryReset, + PrimaryWindowMinutes: &primaryWindow, + SecondaryUsedPercent: &secondaryUsed, + SecondaryResetAfterSeconds: &secondaryReset, + SecondaryWindowMinutes: &secondaryWindow, + UpdatedAt: "2026-02-16T10:00:00Z", + } + + updates := buildCodexUsageExtraUpdates(snapshot, time.Time{}) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_5h_reset_after_seconds"]; got != -15 { + t.Fatalf("codex_5h_reset_after_seconds = %v, want %d", got, -15) + } + if got := updates["codex_5h_reset_at"]; got != "2026-02-16T10:00:00Z" { + t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-16T10:00:00Z") + } +} + +func TestBuildCodexUsageExtraUpdates_NilSnapshot(t *testing.T) { + if got := buildCodexUsageExtraUpdates(nil, time.Now()); got != nil { + t.Fatalf("expected nil updates, got %v", got) + } +} + +func TestBuildCodexUsageExtraUpdates_WithoutNormalizedWindowFields(t *testing.T) { + primaryUsed := 42.0 + fallbackNow := time.Date(2026, 2, 20, 9, 15, 0, 0, time.UTC) + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + UpdatedAt: "", + } + + updates := buildCodexUsageExtraUpdates(snapshot, fallbackNow) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_usage_updated_at"]; got != "2026-02-20T09:15:00Z" { + t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-20T09:15:00Z") + } + if _, ok := updates["codex_5h_reset_at"]; ok { + t.Fatalf("did not expect codex_5h_reset_at in updates: %v", updates["codex_5h_reset_at"]) + } + if _, ok := updates["codex_7d_reset_at"]; ok { + t.Fatalf("did not expect codex_7d_reset_at in updates: %v", updates["codex_7d_reset_at"]) + } +} diff --git a/backend/internal/service/openai_gateway_service_hotpath_test.go b/backend/internal/service/openai_gateway_service_hotpath_test.go new file mode 100644 index 00000000..f73c06c5 --- /dev/null +++ b/backend/internal/service/openai_gateway_service_hotpath_test.go @@ -0,0 +1,141 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractOpenAIRequestMetaFromBody(t *testing.T) { + tests := []struct { + name string + body []byte + wantModel string + wantStream bool + wantPromptKey string + }{ + { + name: "完整字段", + body: []byte(`{"model":"gpt-5","stream":true,"prompt_cache_key":" ses-1 "}`), + wantModel: "gpt-5", + wantStream: true, + wantPromptKey: "ses-1", + }, + { + name: "缺失可选字段", + body: []byte(`{"model":"gpt-4"}`), + wantModel: "gpt-4", + wantStream: false, + wantPromptKey: "", + }, + { + name: "空请求体", + body: nil, + wantModel: "", + wantStream: false, + wantPromptKey: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + model, stream, promptKey := extractOpenAIRequestMetaFromBody(tt.body) + require.Equal(t, tt.wantModel, model) + require.Equal(t, tt.wantStream, stream) + require.Equal(t, tt.wantPromptKey, promptKey) + }) + } +} + +func TestExtractOpenAIReasoningEffortFromBody(t *testing.T) { + tests := []struct { + name string + body []byte + model string + wantNil bool + wantValue string + }{ + { + name: "优先读取 reasoning.effort", + body: []byte(`{"reasoning":{"effort":"medium"}}`), + model: "gpt-5-high", + wantNil: false, + wantValue: "medium", + }, + { + name: "兼容 reasoning_effort", + body: []byte(`{"reasoning_effort":"x-high"}`), + model: "", + wantNil: false, + wantValue: "xhigh", + }, + { + name: "minimal 归一化为空", + body: []byte(`{"reasoning":{"effort":"minimal"}}`), + model: "gpt-5-high", + wantNil: true, + }, + { + name: "缺失字段时从模型后缀推导", + body: []byte(`{"input":"hi"}`), + model: "gpt-5-high", + wantNil: false, + wantValue: "high", + }, + { + name: "未知后缀不返回", + body: []byte(`{"input":"hi"}`), + model: "gpt-5-unknown", + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractOpenAIReasoningEffortFromBody(tt.body, tt.model) + if tt.wantNil { + require.Nil(t, got) + return + } + require.NotNil(t, got) + require.Equal(t, tt.wantValue, *got) + }) + } +} + +func TestGetOpenAIRequestBodyMap_UsesContextCache(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + cached := map[string]any{"model": "cached-model", "stream": true} + c.Set(OpenAIParsedRequestBodyKey, cached) + + got, err := getOpenAIRequestBodyMap(c, []byte(`{invalid-json`)) + require.NoError(t, err) + require.Equal(t, cached, got) +} + +func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) { + _, err := getOpenAIRequestBodyMap(nil, []byte(`{invalid-json`)) + require.Error(t, err) + require.Contains(t, err.Error(), "parse request") +} + +func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + got, err := getOpenAIRequestBodyMap(c, []byte(`{"model":"gpt-5","stream":true}`)) + require.NoError(t, err) + require.Equal(t, "gpt-5", got["model"]) + + cached, ok := c.Get(OpenAIParsedRequestBodyKey) + require.True(t, ok) + cachedMap, ok := cached.(map[string]any) + require.True(t, ok) + require.Equal(t, got, cachedMap) +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 1c2c81ca..89443b69 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "errors" + "fmt" "io" "net/http" "net/http/httptest" @@ -13,9 +14,15 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ AccountRepository = (*stubOpenAIAccountRepo)(nil) +var _ GatewayCache = (*stubGatewayCache)(nil) + type stubOpenAIAccountRepo struct { AccountRepository accounts []Account @@ -124,17 +131,19 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { svc := &OpenAIGatewayService{} + bodyWithKey := []byte(`{"prompt_cache_key":"ses_aaa"}`) + // 1) session_id header wins c.Request.Header.Set("session_id", "sess-123") c.Request.Header.Set("conversation_id", "conv-456") - h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + h1 := svc.GenerateSessionHash(c, bodyWithKey) if h1 == "" { t.Fatalf("expected non-empty hash") } // 2) conversation_id used when session_id absent c.Request.Header.Del("session_id") - h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + h2 := svc.GenerateSessionHash(c, bodyWithKey) if h2 == "" { t.Fatalf("expected non-empty hash") } @@ -144,7 +153,7 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { // 3) prompt_cache_key used when both headers absent c.Request.Header.Del("conversation_id") - h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + h3 := svc.GenerateSessionHash(c, bodyWithKey) if h3 == "" { t.Fatalf("expected non-empty hash") } @@ -153,12 +162,60 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { } // 4) empty when no signals - h4 := svc.GenerateSessionHash(c, map[string]any{}) + h4 := svc.GenerateSessionHash(c, []byte(`{}`)) if h4 != "" { t.Fatalf("expected empty hash when no signals") } } +func TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + c.Request.Header.Set("session_id", "sess-fixed-value") + svc := &OpenAIGatewayService{} + + got := svc.GenerateSessionHash(c, nil) + want := fmt.Sprintf("%016x", xxhash.Sum64String("sess-fixed-value")) + require.Equal(t, want, got) +} + +func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + c.Request.Header.Set("session_id", "sess-legacy-check") + svc := &OpenAIGatewayService{} + + sessionHash := svc.GenerateSessionHash(c, nil) + require.NotEmpty(t, sessionHash) + require.NotNil(t, c.Request) + require.NotNil(t, c.Request.Context()) + require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context())) +} + +func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + svc := &OpenAIGatewayService{} + seed := "openai_ws_ingress:9:100:200" + + got := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), seed) + want := fmt.Sprintf("%016x", xxhash.Sum64String(seed)) + require.Equal(t, want, got) + require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context())) + + empty := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), " ") + require.Equal(t, "", empty) +} + func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { if c.waitCounts != nil { if count, ok := c.waitCounts[accountID]; ok { @@ -204,22 +261,6 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i return nil } -func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - return 0, nil -} - -func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { - return nil, nil -} - -func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { now := time.Now() resetAt := now.Add(10 * time.Minute) @@ -1082,6 +1123,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { } } +func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 1, result.usage.InputTokens) + require.Equal(t, 2, result.usage.OutputTokens) + require.Equal(t, 3, result.usage.CacheReadInputTokens) +} + func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -1165,3 +1243,332 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) { t.Fatalf("expected non-allowlisted host to fail") } } + +// ==================== P1-08 修复:model 替换性能优化测试 ==================== + +func TestReplaceModelInSSELine(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + line string + from string + to string + expected string + }{ + { + name: "顶层 model 字段替换", + line: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`, + from: "gpt-4o", + to: "my-custom-model", + expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`, + }, + { + name: "嵌套 response.model 替换", + line: `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`, + }, + { + name: "model 不匹配时不替换", + line: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + }, + { + name: "无 model 字段时不替换", + line: `data: {"id":"chatcmpl-123","choices":[]}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"chatcmpl-123","choices":[]}`, + }, + { + name: "空 data 行", + line: `data: `, + from: "gpt-4o", + to: "my-model", + expected: `data: `, + }, + { + name: "[DONE] 行", + line: `data: [DONE]`, + from: "gpt-4o", + to: "my-model", + expected: `data: [DONE]`, + }, + { + name: "非 data: 前缀行", + line: `event: message`, + from: "gpt-4o", + to: "my-model", + expected: `event: message`, + }, + { + name: "非法 JSON 不替换", + line: `data: {invalid json}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {invalid json}`, + }, + { + name: "无空格 data: 格式", + line: `data:{"id":"x","model":"gpt-4o"}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"x","model":"my-model"}`, + }, + { + name: "model 名含特殊字符", + line: `data: {"model":"org/model-v2.1-beta"}`, + from: "org/model-v2.1-beta", + to: "custom/alias", + expected: `data: {"model":"custom/alias"}`, + }, + { + name: "空行", + line: "", + from: "gpt-4o", + to: "my-model", + expected: "", + }, + { + name: "保持其他字段不变", + line: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`, + from: "gpt-4o", + to: "alias", + expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`, + }, + { + name: "顶层优先于嵌套:同时存在两个 model", + line: `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`, + from: "gpt-4o", + to: "replaced", + expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to) + require.Equal(t, tt.expected, got) + }) + } +} + +func TestReplaceModelInSSEBody(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + body string + from string + to string + expected string + }{ + { + name: "多行 SSE body 替换", + body: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n", + from: "gpt-4o", + to: "alias", + expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n", + }, + { + name: "无需替换的 body", + body: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n", + from: "gpt-4o", + to: "alias", + expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n", + }, + { + name: "混合 event 和 data 行", + body: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n", + from: "gpt-4o", + to: "alias", + expected: "event: message\ndata: {\"model\":\"alias\"}\n\n", + }, + { + name: "空 body", + body: "", + from: "gpt-4o", + to: "alias", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to) + require.Equal(t, tt.expected, got) + }) + } +} + +func TestReplaceModelInResponseBody(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + body string + from string + to string + expected string + }{ + { + name: "替换顶层 model", + body: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`, + }, + { + name: "model 不匹配不替换", + body: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + }, + { + name: "无 model 字段不替换", + body: `{"id":"chatcmpl-123","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","choices":[]}`, + }, + { + name: "非法 JSON 返回原值", + body: `not json`, + from: "gpt-4o", + to: "alias", + expected: `not json`, + }, + { + name: "空 body 返回原值", + body: ``, + from: "gpt-4o", + to: "alias", + expected: ``, + }, + { + name: "保持嵌套结构不变", + body: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, + from: "gpt-4o", + to: "alias", + expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to) + require.Equal(t, tt.expected, string(got)) + }) + } +} + +func TestExtractOpenAISSEDataLine(t *testing.T) { + tests := []struct { + name string + line string + wantData string + wantOK bool + }{ + {name: "标准格式", line: `data: {"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true}, + {name: "无空格格式", line: `data:{"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true}, + {name: "纯空数据", line: `data: `, wantData: ``, wantOK: true}, + {name: "非 data 行", line: `event: message`, wantData: ``, wantOK: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := extractOpenAISSEDataLine(tt.line) + require.Equal(t, tt.wantOK, ok) + require.Equal(t, tt.wantData, got) + }) + } +} + +func TestParseSSEUsage_SelectiveParsing(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 9, OutputTokens: 8, CacheReadInputTokens: 7} + + // 非 completed 事件,不应覆盖 usage + svc.parseSSEUsage(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`, usage) + require.Equal(t, 9, usage.InputTokens) + require.Equal(t, 8, usage.OutputTokens) + require.Equal(t, 7, usage.CacheReadInputTokens) + + // completed 事件,应提取 usage + svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}}`, usage) + require.Equal(t, 3, usage.InputTokens) + require.Equal(t, 5, usage.OutputTokens) + require.Equal(t, 2, usage.CacheReadInputTokens) +} + +func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { + body := strings.Join([]string{ + `event: message`, + `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`, + `data: {"type":"response.completed","response":{"id":"resp_1","model":"gpt-4o","usage":{"input_tokens":11,"output_tokens":22,"input_tokens_details":{"cached_tokens":3}}}}`, + `data: [DONE]`, + }, "\n") + + finalResp, ok := extractCodexFinalResponse(body) + require.True(t, ok) + require.Contains(t, string(finalResp), `"id":"resp_1"`) + require.Contains(t, string(finalResp), `"input_tokens":11`) +} + +func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.in_progress","response":{"id":"resp_2"}}`, + `data: {"type":"response.completed","response":{"id":"resp_2","model":"gpt-4o","usage":{"input_tokens":7,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, 7, usage.InputTokens) + require.Equal(t, 9, usage.OutputTokens) + require.Equal(t, 1, usage.CacheReadInputTokens) + // Header 可能由上游 Content-Type 透传;关键是 body 已转换为最终 JSON 响应。 + require.NotContains(t, rec.Body.String(), "event:") + require.Contains(t, rec.Body.String(), `"id":"resp_2"`) + require.NotContains(t, rec.Body.String(), "data:") +} + +func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.in_progress","response":{"id":"resp_3"}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, 0, usage.InputTokens) + require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream") + require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`) +} diff --git a/backend/internal/service/openai_json_optimization_benchmark_test.go b/backend/internal/service/openai_json_optimization_benchmark_test.go new file mode 100644 index 00000000..1737804b --- /dev/null +++ b/backend/internal/service/openai_json_optimization_benchmark_test.go @@ -0,0 +1,357 @@ +package service + +import ( + "encoding/json" + "strconv" + "strings" + "testing" + + "github.com/tidwall/gjson" +) + +var ( + benchmarkToolContinuationBoolSink bool + benchmarkWSParseStringSink string + benchmarkWSParseMapSink map[string]any + benchmarkUsageSink OpenAIUsage +) + +func BenchmarkToolContinuationValidationLegacy(b *testing.B) { + reqBody := benchmarkToolContinuationRequestBody() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkToolContinuationBoolSink = legacyValidateFunctionCallOutputContext(reqBody) + } +} + +func BenchmarkToolContinuationValidationOptimized(b *testing.B) { + reqBody := benchmarkToolContinuationRequestBody() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkToolContinuationBoolSink = optimizedValidateFunctionCallOutputContext(reqBody) + } +} + +func BenchmarkWSIngressPayloadParseLegacy(b *testing.B) { + raw := benchmarkWSIngressPayloadBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + eventType, model, promptCacheKey, previousResponseID, payload, err := legacyParseWSIngressPayload(raw) + if err == nil { + benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID + benchmarkWSParseMapSink = payload + } + } +} + +func BenchmarkWSIngressPayloadParseOptimized(b *testing.B) { + raw := benchmarkWSIngressPayloadBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + eventType, model, promptCacheKey, previousResponseID, payload, err := optimizedParseWSIngressPayload(raw) + if err == nil { + benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID + benchmarkWSParseMapSink = payload + } + } +} + +func BenchmarkOpenAIUsageExtractLegacy(b *testing.B) { + body := benchmarkOpenAIUsageJSONBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage, ok := legacyExtractOpenAIUsageFromJSONBytes(body) + if ok { + benchmarkUsageSink = usage + } + } +} + +func BenchmarkOpenAIUsageExtractOptimized(b *testing.B) { + body := benchmarkOpenAIUsageJSONBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage, ok := extractOpenAIUsageFromJSONBytes(body) + if ok { + benchmarkUsageSink = usage + } + } +} + +func benchmarkToolContinuationRequestBody() map[string]any { + input := make([]any, 0, 64) + for i := 0; i < 24; i++ { + input = append(input, map[string]any{ + "type": "text", + "text": "benchmark text", + }) + } + for i := 0; i < 10; i++ { + callID := "call_" + strconv.Itoa(i) + input = append(input, map[string]any{ + "type": "tool_call", + "call_id": callID, + }) + input = append(input, map[string]any{ + "type": "function_call_output", + "call_id": callID, + }) + input = append(input, map[string]any{ + "type": "item_reference", + "id": callID, + }) + } + return map[string]any{ + "model": "gpt-5.3-codex", + "input": input, + } +} + +func benchmarkWSIngressPayloadBytes() []byte { + return []byte(`{"type":"response.create","model":"gpt-5.3-codex","prompt_cache_key":"cache_bench","previous_response_id":"resp_prev_bench","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`) +} + +func benchmarkOpenAIUsageJSONBytes() []byte { + return []byte(`{"id":"resp_bench","object":"response","model":"gpt-5.3-codex","usage":{"input_tokens":3210,"output_tokens":987,"input_tokens_details":{"cached_tokens":456}}}`) +} + +func legacyValidateFunctionCallOutputContext(reqBody map[string]any) bool { + if !legacyHasFunctionCallOutput(reqBody) { + return true + } + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" { + return true + } + if legacyHasToolCallContext(reqBody) { + return true + } + if legacyHasFunctionCallOutputMissingCallID(reqBody) { + return false + } + callIDs := legacyFunctionCallOutputCallIDs(reqBody) + return legacyHasItemReferenceForCallIDs(reqBody, callIDs) +} + +func optimizedValidateFunctionCallOutputContext(reqBody map[string]any) bool { + validation := ValidateFunctionCallOutputContext(reqBody) + if !validation.HasFunctionCallOutput { + return true + } + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" { + return true + } + if validation.HasToolCallContext { + return true + } + if validation.HasFunctionCallOutputMissingCallID { + return false + } + return validation.HasItemReferenceForAllCallIDs +} + +func legacyHasFunctionCallOutput(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType == "function_call_output" { + return true + } + } + return false +} + +func legacyHasToolCallContext(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "tool_call" && itemType != "function_call" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + return true + } + } + return false +} + +func legacyFunctionCallOutputCallIDs(reqBody map[string]any) []string { + if reqBody == nil { + return nil + } + input, ok := reqBody["input"].([]any) + if !ok { + return nil + } + ids := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + ids[callID] = struct{}{} + } + } + if len(ids) == 0 { + return nil + } + callIDs := make([]string, 0, len(ids)) + for id := range ids { + callIDs = append(callIDs, id) + } + return callIDs +} + +func legacyHasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) == "" { + return true + } + } + return false +} + +func legacyHasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool { + if reqBody == nil || len(callIDs) == 0 { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + referenceIDs := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "item_reference" { + continue + } + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + referenceIDs[idValue] = struct{}{} + } + if len(referenceIDs) == 0 { + return false + } + for _, callID := range callIDs { + if _, ok := referenceIDs[callID]; !ok { + return false + } + } + return true +} + +func legacyParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) { + values := gjson.GetManyBytes(raw, "type", "model", "prompt_cache_key", "previous_response_id") + eventType = strings.TrimSpace(values[0].String()) + if eventType == "" { + eventType = "response.create" + } + model = strings.TrimSpace(values[1].String()) + promptCacheKey = strings.TrimSpace(values[2].String()) + previousResponseID = strings.TrimSpace(values[3].String()) + payload = make(map[string]any) + if err = json.Unmarshal(raw, &payload); err != nil { + return "", "", "", "", nil, err + } + if _, exists := payload["type"]; !exists { + payload["type"] = "response.create" + } + return eventType, model, promptCacheKey, previousResponseID, payload, nil +} + +func optimizedParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) { + payload = make(map[string]any) + if err = json.Unmarshal(raw, &payload); err != nil { + return "", "", "", "", nil, err + } + eventType = openAIWSPayloadString(payload, "type") + if eventType == "" { + eventType = "response.create" + payload["type"] = eventType + } + model = openAIWSPayloadString(payload, "model") + promptCacheKey = openAIWSPayloadString(payload, "prompt_cache_key") + previousResponseID = openAIWSPayloadString(payload, "previous_response_id") + return eventType, model, promptCacheKey, previousResponseID, payload, nil +} + +func legacyExtractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { + var response struct { + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails struct { + CachedTokens int `json:"cached_tokens"` + } `json:"input_tokens_details"` + } `json:"usage"` + } + if err := json.Unmarshal(body, &response); err != nil { + return OpenAIUsage{}, false + } + return OpenAIUsage{ + InputTokens: response.Usage.InputTokens, + OutputTokens: response.Usage.OutputTokens, + CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens, + }, true +} diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go new file mode 100644 index 00000000..0840d3b1 --- /dev/null +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -0,0 +1,928 @@ +package service + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func f64p(v float64) *float64 { return &v } + +type httpUpstreamRecorder struct { + lastReq *http.Request + lastBody []byte + + resp *http.Response + err error +} + +func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.lastReq = req + if req != nil && req.Body != nil { + b, _ := io.ReadAll(req.Body) + u.lastBody = b + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewReader(b)) + } + if u.err != nil { + return nil, u.err + } + return u.resp, nil +} + +func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +var structuredLogCaptureMu sync.Mutex + +type inMemoryLogSink struct { + mu sync.Mutex + events []*logger.LogEvent +} + +func (s *inMemoryLogSink) WriteLogEvent(event *logger.LogEvent) { + if event == nil { + return + } + cloned := *event + if event.Fields != nil { + cloned.Fields = make(map[string]any, len(event.Fields)) + for k, v := range event.Fields { + cloned.Fields[k] = v + } + } + s.mu.Lock() + s.events = append(s.events, &cloned) + s.mu.Unlock() +} + +func (s *inMemoryLogSink) ContainsMessage(substr string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, ev := range s.events { + if ev != nil && strings.Contains(ev.Message, substr) { + return true + } + } + return false +} + +func (s *inMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool { + s.mu.Lock() + defer s.mu.Unlock() + wantLevel := strings.ToLower(strings.TrimSpace(level)) + for _, ev := range s.events { + if ev == nil { + continue + } + if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel { + return true + } + } + return false +} + +func (s *inMemoryLogSink) ContainsFieldValue(field, substr string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, ev := range s.events { + if ev == nil || ev.Fields == nil { + continue + } + if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) { + return true + } + } + return false +} + +func (s *inMemoryLogSink) ContainsField(field string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, ev := range s.events { + if ev == nil || ev.Fields == nil { + continue + } + if _, ok := ev.Fields[field]; ok { + return true + } + } + return false +} + +func captureStructuredLog(t *testing.T) (*inMemoryLogSink, func()) { + t.Helper() + structuredLogCaptureMu.Lock() + + err := logger.Init(logger.InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: logger.SamplingOptions{Enabled: false}, + }) + require.NoError(t, err) + + sink := &inMemoryLogSink{} + logger.SetSink(sink) + return sink, func() { + logger.SetSink(nil) + structuredLogCaptureMu.Unlock() + } +} + +func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormalized(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Authorization", "Bearer inbound-should-not-forward") + c.Request.Header.Set("Cookie", "secret=1") + c.Request.Header.Set("X-Api-Key", "sk-inbound") + c.Request.Header.Set("X-Goog-Api-Key", "goog-inbound") + c.Request.Header.Set("Accept-Encoding", "gzip") + c.Request.Header.Set("Proxy-Authorization", "Basic abc") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) + + upstreamSSE := strings.Join([]string{ + `data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`, + "", + "data: [DONE]", + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + openAITokenProvider: &OpenAITokenProvider{ // minimal: will be bypassed by nil cache/service, but GetAccessToken uses provider only if non-nil + accountRepo: nil, + }, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + // Use the gateway method that reads token from credentials when provider is nil. + svc.openAITokenProvider = nil + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + + // 1) 透传 OAuth 请求体与旧链路关键行为保持一致:store=false + stream=true。 + require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) + require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "local-test-instructions", strings.TrimSpace(gjson.GetBytes(upstream.lastBody, "instructions").String())) + // 其余关键字段保持原值。 + require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String()) + + // 2) only auth is replaced; inbound auth/cookie are not forwarded + require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "codex_cli_rs/0.1.0", upstream.lastReq.Header.Get("User-Agent")) + require.Empty(t, upstream.lastReq.Header.Get("Cookie")) + require.Empty(t, upstream.lastReq.Header.Get("X-Api-Key")) + require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key")) + require.Empty(t, upstream.lastReq.Header.Get("Accept-Encoding")) + require.Empty(t, upstream.lastReq.Header.Get("Proxy-Authorization")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) + + // 3) required OAuth headers are present + require.Equal(t, "chatgpt.com", upstream.lastReq.Host) + require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id")) + + // 4) downstream SSE keeps tool name (no toolCorrector) + body := rec.Body.String() + require.Contains(t, body, "apply_patch") + require.NotContains(t, body, "\"name\":\"edit\"") +} + +func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "responses=experimental") + + // Codex 模型且缺少 instructions,应在本地直接 403 拒绝,不触达上游。 + originalBody := []byte(`{"model":"gpt-5.1-codex-max","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "requires a non-empty instructions field") + require.Nil(t, upstream.lastReq) + + require.True(t, logSink.ContainsMessage("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions")) + require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")) + require.True(t, logSink.ContainsFieldValue("reject_reason", "instructions_missing")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + // store=true + stream=false should be forced to store=false + stream=true by applyCodexOAuthTransform (OAuth legacy path) + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": false}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.NoError(t, err) + + // legacy path rewrites request body (not byte-equal) + require.NotEqual(t, inputBody, upstream.lastBody) + require.Contains(t, string(upstream.lastBody), `"store":false`) + require.Contains(t, string(upstream.lastBody), `"stream":true`) +} + +func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + // 复合 UA(前缀不是 codex_cli_rs),历史实现会误判为非 Codex 并走 opencode。 + c.Request.Header.Set("User-Agent", "Mozilla/5.0 codex_cli_rs/0.1.0") + + inputBody := []byte(`{"model":"gpt-5.2","stream":true,"store":false,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": false}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "codex_cli_rs", upstream.lastReq.Header.Get("originator")) + require.NotEqual(t, "opencode", upstream.lastReq.Header.Get("originator")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + + headers := make(http.Header) + headers.Set("Content-Type", "application/json") + headers.Set("x-request-id", "rid") + headers.Set("x-codex-primary-used-percent", "12") + headers.Set("x-codex-secondary-used-percent", "34") + headers.Set("x-codex-primary-window-minutes", "300") + headers.Set("x-codex-secondary-window-minutes", "10080") + headers.Set("x-codex-primary-reset-after-seconds", "1") + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: headers, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + + require.Equal(t, "12", rec.Header().Get("x-codex-primary-used-percent")) + require.Equal(t, "34", rec.Header().Get("x-codex-secondary-used-percent")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughFlag(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"bad"}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + + // should append an upstream error event with passthrough=true + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + arr, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.NotEmpty(t, arr) + require.True(t, arr[len(arr)-1].Passthrough) +} + +func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + // Non-Codex UA + c.Request.Header.Set("User-Agent", "curl/8.0") + + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.NoError(t, err) + require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) + require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent")) +} + +func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true, "codex_cli_only": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.Error(t, err) + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "Codex official clients") +} + +func TestOpenAIGatewayService_CodexCLIOnly_AllowOfficialClientFamilies(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + ua string + originator string + }{ + {name: "codex_cli_rs", ua: "codex_cli_rs/0.99.0", originator: ""}, + {name: "codex_vscode", ua: "codex_vscode/1.0.0", originator: ""}, + {name: "codex_app", ua: "codex_app/2.1.0", originator: ""}, + {name: "originator_codex_chatgpt_desktop", ua: "curl/8.0", originator: "codex_chatgpt_desktop"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", tt.ua) + if tt.originator != "" { + c.Request.Header.Set("originator", tt.originator) + } + + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true, "codex_cli_only": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + }) + } +} + +func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + + upstreamSSE := strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"h"}`, + "", + "data: [DONE]", + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + start := time.Now() + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + // sanity: duration after start + require.GreaterOrEqual(t, time.Since(start), time.Duration(0)) + require.NotNil(t, result.FirstTokenMs) + require.GreaterOrEqual(t, *result.FirstTokenMs, 0) +} + +func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + // 首次写入成功,后续写入失败,模拟客户端中途断开。 + c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 1} + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + + upstreamSSE := strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"h"}`, + "", + `data: {"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.NotNil(t, result.FirstTokenMs) + require.Equal(t, 11, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) +} + +func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEndpoint(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 456, + Name: "apikey-acc", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-api-key", "base_url": "https://api.openai.com"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, originalBody, upstream.lastBody) + require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer sk-api-key", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "10000") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-timeout"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 321, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.True(t, logSink.ContainsMessage("检测到超时相关请求头,将按配置过滤以降低断流风险")) + require.True(t, logSink.ContainsFieldValue("timeout_headers", "x-stainless-timeout=10000")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_InfoWhenStreamEndsWithoutDone(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + // 注意:刻意不发送 [DONE],模拟上游中途断流。 + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-truncate"}}, + Body: io.NopCloser(strings.NewReader("data: {\"type\":\"response.output_text.delta\",\"delta\":\"h\"}\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 654, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.True(t, logSink.ContainsMessage("上游流在未收到 [DONE] 时结束,疑似断流")) + require.True(t, logSink.ContainsMessageAtLevel("上游流在未收到 [DONE] 时结束,疑似断流", "info")) + require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-truncate")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "120000") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-default"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 111, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Empty(t, upstream.lastReq.Header.Get("x-stainless-timeout")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "120000") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-allow"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ + ForceCodexCLI: false, + OpenAIPassthroughAllowTimeoutHeaders: true, + }}, + httpUpstream: upstream, + } + account := &Account{ + ID: 222, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "120000", upstream.lastReq.Header.Get("x-stainless-timeout")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) +} diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index ca7470b9..72f4bbb0 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -2,13 +2,31 @@ package service import ( "context" + "crypto/subtle" + "encoding/json" + "io" + "log/slog" "net/http" + "regexp" + "sort" + "strconv" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) +var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" + +var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`) + +type soraSessionChunk struct { + index int + value string +} + // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { sessionStore *openai.SessionStore @@ -32,7 +50,7 @@ type OpenAIAuthURLResult struct { } // GenerateAuthURL generates an OpenAI OAuth authorization URL -func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) { +func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) { // Generate PKCE values state, err := openai.GenerateState() if err != nil { @@ -68,11 +86,14 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 if redirectURI == "" { redirectURI = openai.DefaultRedirectURI } + normalizedPlatform := normalizeOpenAIOAuthPlatform(platform) + clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform) // Store session session := &openai.OAuthSession{ State: state, CodeVerifier: codeVerifier, + ClientID: clientID, RedirectURI: redirectURI, ProxyURL: proxyURL, CreatedAt: time.Now(), @@ -80,7 +101,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 s.sessionStore.Set(sessionID, session) // Build authorization URL - authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI) + authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform) return &OpenAIAuthURLResult{ AuthURL: authURL, @@ -92,6 +113,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 type OpenAIExchangeCodeInput struct { SessionID string Code string + State string RedirectURI string ProxyID *int64 } @@ -103,6 +125,7 @@ type OpenAITokenInfo struct { IDToken string `json:"id_token,omitempty"` ExpiresIn int64 `json:"expires_in"` ExpiresAt int64 `json:"expires_at"` + ClientID string `json:"client_id,omitempty"` Email string `json:"email,omitempty"` ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` @@ -116,6 +139,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch if !ok { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired") } + if input.State == "" { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required") + } + if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state") + } // Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL proxyURL := session.ProxyURL @@ -134,9 +163,13 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch if input.RedirectURI != "" { redirectURI = input.RedirectURI } + clientID := strings.TrimSpace(session.ClientID) + if clientID == "" { + clientID = openai.ClientID + } // Exchange code for token - tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL) + tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID) if err != nil { return nil, err } @@ -144,8 +177,10 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch // Parse ID token to get user info var userInfo *openai.UserInfo if tokenResp.IDToken != "" { - claims, err := openai.ParseIDToken(tokenResp.IDToken) - if err == nil { + claims, parseErr := openai.ParseIDToken(tokenResp.IDToken) + if parseErr != nil { + slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr) + } else { userInfo = claims.GetUserInfo() } } @@ -159,6 +194,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch IDToken: tokenResp.IDToken, ExpiresIn: int64(tokenResp.ExpiresIn), ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn), + ClientID: clientID, } if userInfo != nil { @@ -173,7 +209,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch // RefreshToken refreshes an OpenAI OAuth token func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) { - tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id. +func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) { + tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) if err != nil { return nil, err } @@ -181,8 +222,10 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri // Parse ID token to get user info var userInfo *openai.UserInfo if tokenResp.IDToken != "" { - claims, err := openai.ParseIDToken(tokenResp.IDToken) - if err == nil { + claims, parseErr := openai.ParseIDToken(tokenResp.IDToken) + if parseErr != nil { + slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr) + } else { userInfo = claims.GetUserInfo() } } @@ -194,6 +237,9 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri ExpiresIn: int64(tokenResp.ExpiresIn), ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn), } + if trimmed := strings.TrimSpace(clientID); trimmed != "" { + tokenInfo.ClientID = trimmed + } if userInfo != nil { tokenInfo.Email = userInfo.Email @@ -205,13 +251,221 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri return tokenInfo, nil } -// RefreshAccountToken refreshes token for an OpenAI account -func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { - if !account.IsOpenAI() { - return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account") +// ExchangeSoraSessionToken exchanges Sora session_token to access_token. +func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) { + sessionToken = normalizeSoraSessionTokenInput(sessionToken) + if strings.TrimSpace(sessionToken) == "" { + return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required") } - refreshToken := account.GetOpenAIRefreshToken() + proxyURL, err := s.resolveProxyURL(ctx, proxyID) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil) + if err != nil { + return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err) + } + req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: 120 * time.Second, + }) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err) + } + resp, err := client.Do(req) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if resp.StatusCode != http.StatusOK { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var sessionResp struct { + AccessToken string `json:"accessToken"` + Expires string `json:"expires"` + User struct { + Email string `json:"email"` + Name string `json:"name"` + } `json:"user"` + } + if err := json.Unmarshal(body, &sessionResp); err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err) + } + if strings.TrimSpace(sessionResp.AccessToken) == "" { + return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token") + } + + expiresAt := time.Now().Add(time.Hour).Unix() + if strings.TrimSpace(sessionResp.Expires) != "" { + if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil { + expiresAt = parsed.Unix() + } + } + expiresIn := expiresAt - time.Now().Unix() + if expiresIn < 0 { + expiresIn = 0 + } + + return &OpenAITokenInfo{ + AccessToken: strings.TrimSpace(sessionResp.AccessToken), + ExpiresIn: expiresIn, + ExpiresAt: expiresAt, + ClientID: openai.SoraClientID, + Email: strings.TrimSpace(sessionResp.User.Email), + }, nil +} + +func normalizeSoraSessionTokenInput(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + + matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1) + if len(matches) == 0 { + return sanitizeSessionToken(trimmed) + } + + chunkMatches := make([]soraSessionChunk, 0, len(matches)) + singleValues := make([]string, 0, len(matches)) + + for _, match := range matches { + if len(match) < 3 { + continue + } + + value := sanitizeSessionToken(match[2]) + if value == "" { + continue + } + + if strings.TrimSpace(match[1]) == "" { + singleValues = append(singleValues, value) + continue + } + + idx, err := strconv.Atoi(strings.TrimSpace(match[1])) + if err != nil || idx < 0 { + continue + } + chunkMatches = append(chunkMatches, soraSessionChunk{ + index: idx, + value: value, + }) + } + + if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" { + return merged + } + + if len(singleValues) > 0 { + return singleValues[len(singleValues)-1] + } + + return "" +} + +func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string { + if len(chunks) == 0 { + return "" + } + + byIndex := make(map[int]string, len(chunks)) + for _, chunk := range chunks { + byIndex[chunk.index] = chunk.value + } + + if _, ok := byIndex[0]; !ok { + return "" + } + if requireComplete { + for idx := 0; idx <= requiredMaxIndex; idx++ { + if _, ok := byIndex[idx]; !ok { + return "" + } + } + } + + orderedIndexes := make([]int, 0, len(byIndex)) + for idx := range byIndex { + orderedIndexes = append(orderedIndexes, idx) + } + sort.Ints(orderedIndexes) + + var builder strings.Builder + for _, idx := range orderedIndexes { + if _, err := builder.WriteString(byIndex[idx]); err != nil { + return "" + } + } + return sanitizeSessionToken(builder.String()) +} + +func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string { + if len(chunks) == 0 { + return "" + } + + requiredMaxIndex := 0 + for _, chunk := range chunks { + if chunk.index > requiredMaxIndex { + requiredMaxIndex = chunk.index + } + } + + groupStarts := make([]int, 0, len(chunks)) + for idx, chunk := range chunks { + if chunk.index == 0 { + groupStarts = append(groupStarts, idx) + } + } + + if len(groupStarts) == 0 { + return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) + } + + for i := len(groupStarts) - 1; i >= 0; i-- { + start := groupStarts[i] + end := len(chunks) + if i+1 < len(groupStarts) { + end = groupStarts[i+1] + } + if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" { + return merged + } + } + + return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) +} + +func sanitizeSessionToken(raw string) string { + token := strings.TrimSpace(raw) + token = strings.Trim(token, "\"'`") + token = strings.TrimSuffix(token, ";") + return strings.TrimSpace(token) +} + +// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account +func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { + if account.Platform != PlatformOpenAI && account.Platform != PlatformSora { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account") + } + if account.Type != AccountTypeOAuth { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account") + } + + refreshToken := account.GetCredential("refresh_token") if refreshToken == "" { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") } @@ -224,7 +478,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A } } - return s.RefreshToken(ctx, refreshToken, proxyURL) + clientID := account.GetCredential("client_id") + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) } // BuildAccountCredentials builds credentials map from token info @@ -232,9 +487,12 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339) creds := map[string]any{ - "access_token": tokenInfo.AccessToken, - "refresh_token": tokenInfo.RefreshToken, - "expires_at": expiresAt, + "access_token": tokenInfo.AccessToken, + "expires_at": expiresAt, + } + // 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌 + if strings.TrimSpace(tokenInfo.RefreshToken) != "" { + creds["refresh_token"] = tokenInfo.RefreshToken } if tokenInfo.IDToken != "" { @@ -252,6 +510,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) if tokenInfo.OrganizationID != "" { creds["organization_id"] = tokenInfo.OrganizationID } + if strings.TrimSpace(tokenInfo.ClientID) != "" { + creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID) + } return creds } @@ -260,3 +521,26 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) func (s *OpenAIOAuthService) Stop() { s.sessionStore.Stop() } + +func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) { + if proxyID == nil { + return "", nil + } + proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) + if err != nil { + return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err) + } + if proxy == nil { + return "", nil + } + return proxy.URL(), nil +} + +func normalizeOpenAIOAuthPlatform(platform string) string { + switch strings.ToLower(strings.TrimSpace(platform)) { + case PlatformSora: + return openai.OAuthPlatformSora + default: + return openai.OAuthPlatformOpenAI + } +} diff --git a/backend/internal/service/openai_oauth_service_auth_url_test.go b/backend/internal/service/openai_oauth_service_auth_url_test.go new file mode 100644 index 00000000..5f26903d --- /dev/null +++ b/backend/internal/service/openai_oauth_service_auth_url_test.go @@ -0,0 +1,67 @@ +package service + +import ( + "context" + "errors" + "net/url" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientAuthURLStub struct{} + +func (s *openaiOAuthClientAuthURLStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientAuthURLStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientAuthURLStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) { + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{}) + defer svc.Stop() + + result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformOpenAI) + require.NoError(t, err) + require.NotEmpty(t, result.AuthURL) + require.NotEmpty(t, result.SessionID) + + parsed, err := url.Parse(result.AuthURL) + require.NoError(t, err) + q := parsed.Query() + require.Equal(t, openai.ClientID, q.Get("client_id")) + require.Equal(t, "true", q.Get("codex_cli_simplified_flow")) + + session, ok := svc.sessionStore.Get(result.SessionID) + require.True(t, ok) + require.Equal(t, openai.ClientID, session.ClientID) +} + +// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的 +// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。 +func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) { + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{}) + defer svc.Stop() + + result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora) + require.NoError(t, err) + require.NotEmpty(t, result.AuthURL) + require.NotEmpty(t, result.SessionID) + + parsed, err := url.Parse(result.AuthURL) + require.NoError(t, err) + q := parsed.Query() + require.Equal(t, openai.ClientID, q.Get("client_id")) + require.Empty(t, q.Get("codex_cli_simplified_flow")) + + session, ok := svc.sessionStore.Get(result.SessionID) + require.True(t, ok) + require.Equal(t, openai.ClientID, session.ClientID) +} diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go new file mode 100644 index 00000000..08da8557 --- /dev/null +++ b/backend/internal/service/openai_oauth_service_sora_session_test.go @@ -0,0 +1,173 @@ +package service + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientNoopStub struct{} + +func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at-token", info.AccessToken) + require.Equal(t, "demo@example.com", info.Email) + require.Greater(t, info.ExpiresAt, int64(0)) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + _, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "missing access token") +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax" + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly", + }, "\n") + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly", + }, "\n") + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "set-cookie", + "__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/", + "set-cookie", + "__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/", + "set-cookie", + "__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/", + }, "\n") + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go new file mode 100644 index 00000000..29252328 --- /dev/null +++ b/backend/internal/service/openai_oauth_service_state_test.go @@ -0,0 +1,106 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientStateStub struct { + exchangeCalled int32 + lastClientID string +} + +func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.exchangeCalled, 1) + s.lastClientID = clientID + return &openai.TokenResponse{ + AccessToken: "at", + RefreshToken: "rt", + ExpiresIn: 3600, + }, nil +} + +func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return s.RefreshToken(ctx, refreshToken, proxyURL) +} + +func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "oauth state is required") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + State: "wrong-state", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid oauth state") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + State: "expected-state", + }) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at", info.AccessToken) + require.Equal(t, openai.ClientID, info.ClientID) + require.Equal(t, openai.ClientID, client.lastClientID) + require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled)) + + _, ok := svc.sessionStore.Get("sid") + require.False(t, ok) +} diff --git a/backend/internal/service/openai_previous_response_id.go b/backend/internal/service/openai_previous_response_id.go new file mode 100644 index 00000000..95865086 --- /dev/null +++ b/backend/internal/service/openai_previous_response_id.go @@ -0,0 +1,37 @@ +package service + +import ( + "regexp" + "strings" +) + +const ( + OpenAIPreviousResponseIDKindEmpty = "empty" + OpenAIPreviousResponseIDKindResponseID = "response_id" + OpenAIPreviousResponseIDKindMessageID = "message_id" + OpenAIPreviousResponseIDKindUnknown = "unknown" +) + +var ( + openAIResponseIDPattern = regexp.MustCompile(`^resp_[A-Za-z0-9_-]{1,256}$`) + openAIMessageIDPattern = regexp.MustCompile(`^(msg|message|item|chatcmpl)_[A-Za-z0-9_-]{1,256}$`) +) + +// ClassifyOpenAIPreviousResponseIDKind classifies previous_response_id to improve diagnostics. +func ClassifyOpenAIPreviousResponseIDKind(id string) string { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return OpenAIPreviousResponseIDKindEmpty + } + if openAIResponseIDPattern.MatchString(trimmed) { + return OpenAIPreviousResponseIDKindResponseID + } + if openAIMessageIDPattern.MatchString(strings.ToLower(trimmed)) { + return OpenAIPreviousResponseIDKindMessageID + } + return OpenAIPreviousResponseIDKindUnknown +} + +func IsOpenAIPreviousResponseIDLikelyMessageID(id string) bool { + return ClassifyOpenAIPreviousResponseIDKind(id) == OpenAIPreviousResponseIDKindMessageID +} diff --git a/backend/internal/service/openai_previous_response_id_test.go b/backend/internal/service/openai_previous_response_id_test.go new file mode 100644 index 00000000..7867b864 --- /dev/null +++ b/backend/internal/service/openai_previous_response_id_test.go @@ -0,0 +1,34 @@ +package service + +import "testing" + +func TestClassifyOpenAIPreviousResponseIDKind(t *testing.T) { + tests := []struct { + name string + id string + want string + }{ + {name: "empty", id: " ", want: OpenAIPreviousResponseIDKindEmpty}, + {name: "response_id", id: "resp_0906a621bc423a8d0169a108637ef88197b74b0e2f37ba358f", want: OpenAIPreviousResponseIDKindResponseID}, + {name: "message_id", id: "msg_123456", want: OpenAIPreviousResponseIDKindMessageID}, + {name: "item_id", id: "item_abcdef", want: OpenAIPreviousResponseIDKindMessageID}, + {name: "unknown", id: "foo_123456", want: OpenAIPreviousResponseIDKindUnknown}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := ClassifyOpenAIPreviousResponseIDKind(tc.id); got != tc.want { + t.Fatalf("ClassifyOpenAIPreviousResponseIDKind(%q)=%q want=%q", tc.id, got, tc.want) + } + }) + } +} + +func TestIsOpenAIPreviousResponseIDLikelyMessageID(t *testing.T) { + if !IsOpenAIPreviousResponseIDLikelyMessageID("msg_123") { + t.Fatal("expected msg_123 to be identified as message id") + } + if IsOpenAIPreviousResponseIDLikelyMessageID("resp_123") { + t.Fatal("expected resp_123 not to be identified as message id") + } +} diff --git a/backend/internal/service/openai_sticky_compat.go b/backend/internal/service/openai_sticky_compat.go new file mode 100644 index 00000000..e897debc --- /dev/null +++ b/backend/internal/service/openai_sticky_compat.go @@ -0,0 +1,214 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "sync/atomic" + "time" + + "github.com/cespare/xxhash/v2" + "github.com/gin-gonic/gin" +) + +type openAILegacySessionHashContextKey struct{} + +var openAILegacySessionHashKey = openAILegacySessionHashContextKey{} + +var ( + openAIStickyLegacyReadFallbackTotal atomic.Int64 + openAIStickyLegacyReadFallbackHit atomic.Int64 + openAIStickyLegacyDualWriteTotal atomic.Int64 +) + +func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal int64) { + return openAIStickyLegacyReadFallbackTotal.Load(), + openAIStickyLegacyReadFallbackHit.Load(), + openAIStickyLegacyDualWriteTotal.Load() +} + +func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) { + normalized := strings.TrimSpace(sessionID) + if normalized == "" { + return "", "" + } + + currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized)) + sum := sha256.Sum256([]byte(normalized)) + legacyHash = hex.EncodeToString(sum[:]) + return currentHash, legacyHash +} + +func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context { + if ctx == nil { + return nil + } + trimmed := strings.TrimSpace(legacyHash) + if trimmed == "" { + return ctx + } + return context.WithValue(ctx, openAILegacySessionHashKey, trimmed) +} + +func openAILegacySessionHashFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + value, _ := ctx.Value(openAILegacySessionHashKey).(string) + return strings.TrimSpace(value) +} + +func attachOpenAILegacySessionHashToGin(c *gin.Context, legacyHash string) { + if c == nil || c.Request == nil { + return + } + c.Request = c.Request.WithContext(withOpenAILegacySessionHash(c.Request.Context(), legacyHash)) +} + +func (s *OpenAIGatewayService) openAISessionHashReadOldFallbackEnabled() bool { + if s == nil || s.cfg == nil { + return true + } + return s.cfg.Gateway.OpenAIWS.SessionHashReadOldFallback +} + +func (s *OpenAIGatewayService) openAISessionHashDualWriteOldEnabled() bool { + if s == nil || s.cfg == nil { + return true + } + return s.cfg.Gateway.OpenAIWS.SessionHashDualWriteOld +} + +func (s *OpenAIGatewayService) openAISessionCacheKey(sessionHash string) string { + normalized := strings.TrimSpace(sessionHash) + if normalized == "" { + return "" + } + return "openai:" + normalized +} + +func (s *OpenAIGatewayService) openAILegacySessionCacheKey(ctx context.Context, sessionHash string) string { + legacyHash := openAILegacySessionHashFromContext(ctx) + if legacyHash == "" { + return "" + } + legacyKey := "openai:" + legacyHash + if legacyKey == s.openAISessionCacheKey(sessionHash) { + return "" + } + return legacyKey +} + +func (s *OpenAIGatewayService) openAIStickyLegacyTTL(ttl time.Duration) time.Duration { + legacyTTL := ttl + if legacyTTL <= 0 { + legacyTTL = openaiStickySessionTTL + } + if legacyTTL > 10*time.Minute { + return 10 * time.Minute + } + return legacyTTL +} + +func (s *OpenAIGatewayService) getStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) { + if s == nil || s.cache == nil { + return 0, nil + } + + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return 0, nil + } + + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), primaryKey) + if err == nil && accountID > 0 { + return accountID, nil + } + if !s.openAISessionHashReadOldFallbackEnabled() { + return accountID, err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey == "" { + return accountID, err + } + + openAIStickyLegacyReadFallbackTotal.Add(1) + legacyAccountID, legacyErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), legacyKey) + if legacyErr == nil && legacyAccountID > 0 { + openAIStickyLegacyReadFallbackHit.Add(1) + return legacyAccountID, nil + } + return accountID, err +} + +func (s *OpenAIGatewayService) setStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string, accountID int64, ttl time.Duration) error { + if s == nil || s.cache == nil || accountID <= 0 { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), primaryKey, accountID, ttl); err != nil { + return err + } + + if !s.openAISessionHashDualWriteOldEnabled() { + return nil + } + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey == "" { + return nil + } + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), legacyKey, accountID, s.openAIStickyLegacyTTL(ttl)); err != nil { + return err + } + openAIStickyLegacyDualWriteTotal.Add(1) + return nil +} + +func (s *OpenAIGatewayService) refreshStickySessionTTL(ctx context.Context, groupID *int64, sessionHash string, ttl time.Duration) error { + if s == nil || s.cache == nil { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), primaryKey, ttl) + if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() { + return err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey != "" { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), legacyKey, s.openAIStickyLegacyTTL(ttl)) + } + return err +} + +func (s *OpenAIGatewayService) deleteStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) error { + if s == nil || s.cache == nil { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + err := s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), primaryKey) + if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() { + return err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey != "" { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), legacyKey) + } + return err +} diff --git a/backend/internal/service/openai_sticky_compat_test.go b/backend/internal/service/openai_sticky_compat_test.go new file mode 100644 index 00000000..9f57c358 --- /dev/null +++ b/backend/internal/service/openai_sticky_compat_test.go @@ -0,0 +1,96 @@ +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestGetStickySessionAccountID_FallbackToLegacyKey(t *testing.T) { + beforeFallbackTotal, beforeFallbackHit, _ := openAIStickyCompatStats() + + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:legacy-hash": 42, + }, + } + svc := &OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashReadOldFallback: true, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + accountID, err := svc.getStickySessionAccountID(ctx, nil, "new-hash") + require.NoError(t, err) + require.Equal(t, int64(42), accountID) + + afterFallbackTotal, afterFallbackHit, _ := openAIStickyCompatStats() + require.Equal(t, beforeFallbackTotal+1, afterFallbackTotal) + require.Equal(t, beforeFallbackHit+1, afterFallbackHit) +} + +func TestSetStickySessionAccountID_DualWriteOldEnabled(t *testing.T) { + _, _, beforeDualWriteTotal := openAIStickyCompatStats() + + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + svc := &OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashDualWriteOld: true, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL) + require.NoError(t, err) + require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"]) + require.Equal(t, int64(9), cache.sessionBindings["openai:legacy-hash"]) + + _, _, afterDualWriteTotal := openAIStickyCompatStats() + require.Equal(t, beforeDualWriteTotal+1, afterDualWriteTotal) +} + +func TestSetStickySessionAccountID_DualWriteOldDisabled(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + svc := &OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashDualWriteOld: false, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL) + require.NoError(t, err) + require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"]) + _, exists := cache.sessionBindings["openai:legacy-hash"] + require.False(t, exists) +} + +func TestSnapshotOpenAICompatibilityFallbackMetrics(t *testing.T) { + before := SnapshotOpenAICompatibilityFallbackMetrics() + + ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true) + _, _ = ThinkingEnabledFromContext(ctx) + + after := SnapshotOpenAICompatibilityFallbackMetrics() + require.GreaterOrEqual(t, after.MetadataLegacyFallbackTotal, before.MetadataLegacyFallbackTotal+1) + require.GreaterOrEqual(t, after.MetadataLegacyFallbackThinkingEnabledTotal, before.MetadataLegacyFallbackThinkingEnabledTotal+1) +} diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 87a7713b..a8a6b96c 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -4,16 +4,74 @@ import ( "context" "errors" "log/slog" + "math/rand/v2" "strings" + "sync/atomic" "time" ) const ( - openAITokenRefreshSkew = 3 * time.Minute - openAITokenCacheSkew = 5 * time.Minute - openAILockWaitTime = 200 * time.Millisecond + openAITokenRefreshSkew = 3 * time.Minute + openAITokenCacheSkew = 5 * time.Minute + openAILockInitialWait = 20 * time.Millisecond + openAILockMaxWait = 120 * time.Millisecond + openAILockMaxAttempts = 5 + openAILockJitterRatio = 0.2 + openAILockWarnThresholdMs = 250 ) +// OpenAITokenRuntimeMetrics 表示 OpenAI token 刷新与锁竞争保护指标快照。 +type OpenAITokenRuntimeMetrics struct { + RefreshRequests int64 + RefreshSuccess int64 + RefreshFailure int64 + LockAcquireFailure int64 + LockContention int64 + LockWaitSamples int64 + LockWaitTotalMs int64 + LockWaitHit int64 + LockWaitMiss int64 + LastObservedUnixMs int64 +} + +type openAITokenRuntimeMetricsStore struct { + refreshRequests atomic.Int64 + refreshSuccess atomic.Int64 + refreshFailure atomic.Int64 + lockAcquireFailure atomic.Int64 + lockContention atomic.Int64 + lockWaitSamples atomic.Int64 + lockWaitTotalMs atomic.Int64 + lockWaitHit atomic.Int64 + lockWaitMiss atomic.Int64 + lastObservedUnixMs atomic.Int64 +} + +func (m *openAITokenRuntimeMetricsStore) snapshot() OpenAITokenRuntimeMetrics { + if m == nil { + return OpenAITokenRuntimeMetrics{} + } + return OpenAITokenRuntimeMetrics{ + RefreshRequests: m.refreshRequests.Load(), + RefreshSuccess: m.refreshSuccess.Load(), + RefreshFailure: m.refreshFailure.Load(), + LockAcquireFailure: m.lockAcquireFailure.Load(), + LockContention: m.lockContention.Load(), + LockWaitSamples: m.lockWaitSamples.Load(), + LockWaitTotalMs: m.lockWaitTotalMs.Load(), + LockWaitHit: m.lockWaitHit.Load(), + LockWaitMiss: m.lockWaitMiss.Load(), + LastObservedUnixMs: m.lastObservedUnixMs.Load(), + } +} + +func (m *openAITokenRuntimeMetricsStore) touchNow() { + if m == nil { + return + } + m.lastObservedUnixMs.Store(time.Now().UnixMilli()) +} + // OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) type OpenAITokenCache = GeminiTokenCache @@ -22,6 +80,7 @@ type OpenAITokenProvider struct { accountRepo AccountRepository tokenCache OpenAITokenCache openAIOAuthService *OpenAIOAuthService + metrics *openAITokenRuntimeMetricsStore } func NewOpenAITokenProvider( @@ -33,16 +92,32 @@ func NewOpenAITokenProvider( accountRepo: accountRepo, tokenCache: tokenCache, openAIOAuthService: openAIOAuthService, + metrics: &openAITokenRuntimeMetricsStore{}, + } +} + +func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics { + if p == nil { + return OpenAITokenRuntimeMetrics{} + } + p.ensureMetrics() + return p.metrics.snapshot() +} + +func (p *OpenAITokenProvider) ensureMetrics() { + if p != nil && p.metrics == nil { + p.metrics = &openAITokenRuntimeMetricsStore{} } } // GetAccessToken 获取有效的 access_token func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { + p.ensureMetrics() if account == nil { return "", errors.New("account is nil") } - if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { - return "", errors.New("not an openai oauth account") + if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth { + return "", errors.New("not an openai/sora oauth account") } cacheKey := OpenAITokenCacheKey(account) @@ -64,6 +139,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew refreshFailed := false if needsRefresh && p.tokenCache != nil { + p.metrics.refreshRequests.Add(1) + p.metrics.touchNow() locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) if lockErr == nil && locked { defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() @@ -80,16 +157,23 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } expiresAt = account.GetCredentialAsTime("expires_at") if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) + p.metrics.refreshFailure.Add(1) refreshFailed = true // 无法刷新,标记失败 } else { tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) if err != nil { // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) refreshFailed = true // 刷新失败,标记以使用短 TTL } else { + p.metrics.refreshSuccess.Add(1) newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) for k, v := range account.Credentials { if _, exists := newCredentials[k]; !exists { @@ -106,6 +190,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } else if lockErr != nil { // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时) + p.metrics.lockAcquireFailure.Add(1) + p.metrics.touchNow() slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) // 检查 ctx 是否已取消 @@ -124,15 +210,22 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) + p.metrics.refreshFailure.Add(1) refreshFailed = true } else { tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) if err != nil { slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) refreshFailed = true } else { + p.metrics.refreshSuccess.Add(1) newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) for k, v := range account.Credentials { if _, exists := newCredentials[k]; !exists { @@ -148,16 +241,21 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } } else { - // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存 - time.Sleep(openAILockWaitTime) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + // 锁被其他 worker 持有:使用短轮询+jitter,降低固定等待导致的尾延迟台阶。 + p.metrics.lockContention.Add(1) + p.metrics.touchNow() + token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) + if waitErr != nil { + return "", waitErr + } + if strings.TrimSpace(token) != "" { slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) return token, nil } } } - accessToken := account.GetOpenAIAccessToken() + accessToken := account.GetCredential("access_token") if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found in credentials") } @@ -198,3 +296,64 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou return accessToken, nil } + +func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) { + wait := openAILockInitialWait + totalWaitMs := int64(0) + for i := 0; i < openAILockMaxAttempts; i++ { + actualWait := jitterLockWait(wait) + timer := time.NewTimer(actualWait) + select { + case <-ctx.Done(): + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + return "", ctx.Err() + case <-timer.C: + } + + waitMs := actualWait.Milliseconds() + if waitMs < 0 { + waitMs = 0 + } + totalWaitMs += waitMs + p.metrics.lockWaitSamples.Add(1) + p.metrics.lockWaitTotalMs.Add(waitMs) + p.metrics.touchNow() + + token, err := p.tokenCache.GetAccessToken(ctx, cacheKey) + if err == nil && strings.TrimSpace(token) != "" { + p.metrics.lockWaitHit.Add(1) + if totalWaitMs >= openAILockWarnThresholdMs { + slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", i+1) + } + return token, nil + } + + if wait < openAILockMaxWait { + wait *= 2 + if wait > openAILockMaxWait { + wait = openAILockMaxWait + } + } + } + + p.metrics.lockWaitMiss.Add(1) + if totalWaitMs >= openAILockWarnThresholdMs { + slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", openAILockMaxAttempts) + } + return "", nil +} + +func jitterLockWait(base time.Duration) time.Duration { + if base <= 0 { + return 0 + } + minFactor := 1 - openAILockJitterRatio + maxFactor := 1 + openAILockJitterRatio + factor := minFactor + rand.Float64()*(maxFactor-minFactor) + return time.Duration(float64(base) * factor) +} diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index c2e3dbb0..1cd92367 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai oauth account") + require.Contains(t, err.Error(), "not an openai/sora oauth account") require.Empty(t, token) } @@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai oauth account") + require.Contains(t, err.Error(), "not an openai/sora oauth account") require.Empty(t, token) } @@ -808,3 +808,119 @@ func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) { require.Contains(t, err.Error(), "access_token not found") require.Empty(t, token) } + +func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false // 模拟锁被其他 worker 持有 + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 207, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + cacheKey := OpenAITokenCacheKey(account) + go func() { + time.Sleep(5 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "winner-token", token) +} + +func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false // 模拟锁被其他 worker 持有 + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 208, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + provider := NewOpenAITokenProvider(nil, cache, nil) + start := time.Now() + token, err := provider.GetAccessToken(ctx, account) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + require.Empty(t, token) + require.Less(t, time.Since(start), 50*time.Millisecond) +} + +func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 209, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + cacheKey := OpenAITokenCacheKey(account) + go func() { + time.Sleep(10 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "winner-token", token) + + metrics := provider.SnapshotRuntimeMetrics() + require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1)) + require.GreaterOrEqual(t, metrics.LockContention, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitSamples, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitHit, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitTotalMs, int64(0)) + require.GreaterOrEqual(t, metrics.LastObservedUnixMs, int64(1)) +} + +func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockErr = errors.New("redis lock error") + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 210, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + _, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + + metrics := provider.SnapshotRuntimeMetrics() + require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1)) + require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1)) +} diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go index e59082b2..dea3c172 100644 --- a/backend/internal/service/openai_tool_continuation.go +++ b/backend/internal/service/openai_tool_continuation.go @@ -2,6 +2,24 @@ package service import "strings" +// ToolContinuationSignals 聚合工具续链相关信号,避免重复遍历 input。 +type ToolContinuationSignals struct { + HasFunctionCallOutput bool + HasFunctionCallOutputMissingCallID bool + HasToolCallContext bool + HasItemReference bool + HasItemReferenceForAllCallIDs bool + FunctionCallOutputCallIDs []string +} + +// FunctionCallOutputValidation 汇总 function_call_output 关联性校验结果。 +type FunctionCallOutputValidation struct { + HasFunctionCallOutput bool + HasToolCallContext bool + HasFunctionCallOutputMissingCallID bool + HasItemReferenceForAllCallIDs bool +} + // NeedsToolContinuation 判定请求是否需要工具调用续链处理。 // 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、 // 或显式声明 tools/tool_choice。 @@ -18,107 +36,191 @@ func NeedsToolContinuation(reqBody map[string]any) bool { if hasToolChoiceSignal(reqBody) { return true } - if inputHasType(reqBody, "function_call_output") { - return true + input, ok := reqBody["input"].([]any) + if !ok { + return false } - if inputHasType(reqBody, "item_reference") { - return true + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType == "function_call_output" || itemType == "item_reference" { + return true + } } return false } +// AnalyzeToolContinuationSignals 单次遍历 input,提取 function_call_output/tool_call/item_reference 相关信号。 +func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals { + signals := ToolContinuationSignals{} + if reqBody == nil { + return signals + } + input, ok := reqBody["input"].([]any) + if !ok { + return signals + } + + var callIDs map[string]struct{} + var referenceIDs map[string]struct{} + + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + switch itemType { + case "tool_call", "function_call": + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) != "" { + signals.HasToolCallContext = true + } + case "function_call_output": + signals.HasFunctionCallOutput = true + callID, _ := itemMap["call_id"].(string) + callID = strings.TrimSpace(callID) + if callID == "" { + signals.HasFunctionCallOutputMissingCallID = true + continue + } + if callIDs == nil { + callIDs = make(map[string]struct{}) + } + callIDs[callID] = struct{}{} + case "item_reference": + signals.HasItemReference = true + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + if referenceIDs == nil { + referenceIDs = make(map[string]struct{}) + } + referenceIDs[idValue] = struct{}{} + } + } + + if len(callIDs) == 0 { + return signals + } + signals.FunctionCallOutputCallIDs = make([]string, 0, len(callIDs)) + allReferenced := len(referenceIDs) > 0 + for callID := range callIDs { + signals.FunctionCallOutputCallIDs = append(signals.FunctionCallOutputCallIDs, callID) + if allReferenced { + if _, ok := referenceIDs[callID]; !ok { + allReferenced = false + } + } + } + signals.HasItemReferenceForAllCallIDs = allReferenced + return signals +} + +// ValidateFunctionCallOutputContext 为 handler 提供低开销校验结果: +// 1) 无 function_call_output 直接返回 +// 2) 若已存在 tool_call/function_call 上下文则提前返回 +// 3) 仅在无工具上下文时才构建 call_id / item_reference 集合 +func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation { + result := FunctionCallOutputValidation{} + if reqBody == nil { + return result + } + input, ok := reqBody["input"].([]any) + if !ok { + return result + } + + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + switch itemType { + case "function_call_output": + result.HasFunctionCallOutput = true + case "tool_call", "function_call": + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) != "" { + result.HasToolCallContext = true + } + } + if result.HasFunctionCallOutput && result.HasToolCallContext { + return result + } + } + + if !result.HasFunctionCallOutput || result.HasToolCallContext { + return result + } + + callIDs := make(map[string]struct{}) + referenceIDs := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + switch itemType { + case "function_call_output": + callID, _ := itemMap["call_id"].(string) + callID = strings.TrimSpace(callID) + if callID == "" { + result.HasFunctionCallOutputMissingCallID = true + continue + } + callIDs[callID] = struct{}{} + case "item_reference": + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + referenceIDs[idValue] = struct{}{} + } + } + + if len(callIDs) == 0 || len(referenceIDs) == 0 { + return result + } + allReferenced := true + for callID := range callIDs { + if _, ok := referenceIDs[callID]; !ok { + allReferenced = false + break + } + } + result.HasItemReferenceForAllCallIDs = allReferenced + return result +} + // HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。 func HasFunctionCallOutput(reqBody map[string]any) bool { - if reqBody == nil { - return false - } - return inputHasType(reqBody, "function_call_output") + return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput } // HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call, // 用于判断 function_call_output 是否具备可关联的上下文。 func HasToolCallContext(reqBody map[string]any) bool { - if reqBody == nil { - return false - } - input, ok := reqBody["input"].([]any) - if !ok { - return false - } - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType != "tool_call" && itemType != "function_call" { - continue - } - if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { - return true - } - } - return false + return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext } // FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。 // 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。 func FunctionCallOutputCallIDs(reqBody map[string]any) []string { - if reqBody == nil { - return nil - } - input, ok := reqBody["input"].([]any) - if !ok { - return nil - } - ids := make(map[string]struct{}) - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType != "function_call_output" { - continue - } - if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { - ids[callID] = struct{}{} - } - } - if len(ids) == 0 { - return nil - } - result := make([]string, 0, len(ids)) - for id := range ids { - result = append(result, id) - } - return result + return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs } // HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。 func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { - if reqBody == nil { - return false - } - input, ok := reqBody["input"].([]any) - if !ok { - return false - } - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType != "function_call_output" { - continue - } - callID, _ := itemMap["call_id"].(string) - if strings.TrimSpace(callID) == "" { - return true - } - } - return false + return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID } // HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。 @@ -152,32 +254,13 @@ func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool { return false } for _, callID := range callIDs { - if _, ok := referenceIDs[callID]; !ok { + if _, ok := referenceIDs[strings.TrimSpace(callID)]; !ok { return false } } return true } -// inputHasType 判断 input 中是否存在指定类型的 item。 -func inputHasType(reqBody map[string]any, want string) bool { - input, ok := reqBody["input"].([]any) - if !ok { - return false - } - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType == want { - return true - } - } - return false -} - // hasNonEmptyString 判断字段是否为非空字符串。 func hasNonEmptyString(value any) bool { stringValue, ok := value.(string) diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go index f4719275..348723a6 100644 --- a/backend/internal/service/openai_tool_corrector.go +++ b/backend/internal/service/openai_tool_corrector.go @@ -1,10 +1,15 @@ package service import ( - "encoding/json" + "bytes" "fmt" - "log" + "strconv" + "strings" "sync" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射 @@ -61,237 +66,273 @@ func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, boo if data == "" || data == "\n" { return data, false } + correctedBytes, corrected := c.CorrectToolCallsInSSEBytes([]byte(data)) + if !corrected { + return data, false + } + return string(correctedBytes), true +} - // 尝试解析 JSON - var payload map[string]any - if err := json.Unmarshal([]byte(data), &payload); err != nil { - // 不是有效的 JSON,直接返回原数据 +// CorrectToolCallsInSSEBytes 修正 SSE JSON 数据中的工具调用(字节路径)。 +// 返回修正后的数据和是否进行了修正。 +func (c *CodexToolCorrector) CorrectToolCallsInSSEBytes(data []byte) ([]byte, bool) { + if len(bytes.TrimSpace(data)) == 0 { + return data, false + } + if !mayContainToolCallPayload(data) { + return data, false + } + if !gjson.ValidBytes(data) { + // 不是有效 JSON,直接返回原数据 return data, false } + updated := data corrected := false - - // 处理 tool_calls 数组 - if toolCalls, ok := payload["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { + collect := func(changed bool, next []byte) { + if changed { corrected = true + updated = next } } - // 处理 function_call 对象 - if functionCall, ok := payload["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } + if next, changed := c.correctToolCallsArrayAtPath(updated, "tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, "function_call"); changed { + collect(changed, next) + } + if next, changed := c.correctToolCallsArrayAtPath(updated, "delta.tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, "delta.function_call"); changed { + collect(changed, next) } - // 处理 delta.tool_calls - if delta, ok := payload["delta"].(map[string]any); ok { - if toolCalls, ok := delta["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } + choicesCount := int(gjson.GetBytes(updated, "choices.#").Int()) + for i := 0; i < choicesCount; i++ { + prefix := "choices." + strconv.Itoa(i) + if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".message.tool_calls"); changed { + collect(changed, next) } - if functionCall, ok := delta["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } + if next, changed := c.correctFunctionAtPath(updated, prefix+".message.function_call"); changed { + collect(changed, next) } - } - - // 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls - if choices, ok := payload["choices"].([]any); ok { - for _, choice := range choices { - if choiceMap, ok := choice.(map[string]any); ok { - // 处理 message 中的工具调用 - if message, ok := choiceMap["message"].(map[string]any); ok { - if toolCalls, ok := message["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } - } - if functionCall, ok := message["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } - } - } - // 处理 delta 中的工具调用 - if delta, ok := choiceMap["delta"].(map[string]any); ok { - if toolCalls, ok := delta["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } - } - if functionCall, ok := delta["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } - } - } - } + if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".delta.tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, prefix+".delta.function_call"); changed { + collect(changed, next) } } if !corrected { return data, false } + return updated, true +} - // 序列化回 JSON - correctedBytes, err := json.Marshal(payload) - if err != nil { - log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err) +func mayContainToolCallPayload(data []byte) bool { + // 快速路径:多数 token / 文本事件不包含工具字段,避免进入 JSON 解析热路径。 + return bytes.Contains(data, []byte(`"tool_calls"`)) || + bytes.Contains(data, []byte(`"function_call"`)) || + bytes.Contains(data, []byte(`"function":{"name"`)) +} + +// correctToolCallsArrayAtPath 修正指定路径下 tool_calls 数组中的工具名称。 +func (c *CodexToolCorrector) correctToolCallsArrayAtPath(data []byte, toolCallsPath string) ([]byte, bool) { + count := int(gjson.GetBytes(data, toolCallsPath+".#").Int()) + if count <= 0 { return data, false } - - return string(correctedBytes), true -} - -// correctToolCallsArray 修正工具调用数组中的工具名称 -func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool { + updated := data corrected := false - for _, toolCall := range toolCalls { - if toolCallMap, ok := toolCall.(map[string]any); ok { - if function, ok := toolCallMap["function"].(map[string]any); ok { - if c.correctFunctionCall(function) { - corrected = true - } - } + for i := 0; i < count; i++ { + functionPath := toolCallsPath + "." + strconv.Itoa(i) + ".function" + if next, changed := c.correctFunctionAtPath(updated, functionPath); changed { + updated = next + corrected = true } } - return corrected + return updated, corrected } -// correctFunctionCall 修正单个函数调用的工具名称和参数 -func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool { - name, ok := functionCall["name"].(string) - if !ok || name == "" { - return false +// correctFunctionAtPath 修正指定路径下单个函数调用的工具名称和参数。 +func (c *CodexToolCorrector) correctFunctionAtPath(data []byte, functionPath string) ([]byte, bool) { + namePath := functionPath + ".name" + nameResult := gjson.GetBytes(data, namePath) + if !nameResult.Exists() || nameResult.Type != gjson.String { + return data, false } - + name := strings.TrimSpace(nameResult.Str) + if name == "" { + return data, false + } + updated := data corrected := false // 查找并修正工具名称 if correctName, found := codexToolNameMapping[name]; found { - functionCall["name"] = correctName - c.recordCorrection(name, correctName) - corrected = true - name = correctName // 使用修正后的名称进行参数修正 + if next, err := sjson.SetBytes(updated, namePath, correctName); err == nil { + updated = next + c.recordCorrection(name, correctName) + corrected = true + name = correctName // 使用修正后的名称进行参数修正 + } } // 修正工具参数(基于工具名称) - if c.correctToolParameters(name, functionCall) { + if next, changed := c.correctToolParametersAtPath(updated, functionPath+".arguments", name); changed { + updated = next corrected = true } - - return corrected + return updated, corrected } -// correctToolParameters 修正工具参数以符合 OpenCode 规范 -func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool { - arguments, ok := functionCall["arguments"] - if !ok { - return false +// correctToolParametersAtPath 修正指定路径下 arguments 参数。 +func (c *CodexToolCorrector) correctToolParametersAtPath(data []byte, argumentsPath, toolName string) ([]byte, bool) { + if toolName != "bash" && toolName != "edit" { + return data, false } - // arguments 可能是字符串(JSON)或已解析的 map - var argsMap map[string]any - switch v := arguments.(type) { - case string: - // 解析 JSON 字符串 - if err := json.Unmarshal([]byte(v), &argsMap); err != nil { - return false + args := gjson.GetBytes(data, argumentsPath) + if !args.Exists() { + return data, false + } + + switch args.Type { + case gjson.String: + argsJSON := strings.TrimSpace(args.Str) + if !gjson.Valid(argsJSON) { + return data, false } - case map[string]any: - argsMap = v + if !gjson.Parse(argsJSON).IsObject() { + return data, false + } + nextArgsJSON, corrected := c.correctToolArgumentsJSON(argsJSON, toolName) + if !corrected { + return data, false + } + next, err := sjson.SetBytes(data, argumentsPath, nextArgsJSON) + if err != nil { + return data, false + } + return next, true + case gjson.JSON: + if !args.IsObject() || !gjson.Valid(args.Raw) { + return data, false + } + nextArgsJSON, corrected := c.correctToolArgumentsJSON(args.Raw, toolName) + if !corrected { + return data, false + } + next, err := sjson.SetRawBytes(data, argumentsPath, []byte(nextArgsJSON)) + if err != nil { + return data, false + } + return next, true default: - return false + return data, false + } +} + +// correctToolArgumentsJSON 修正工具参数 JSON(对象字符串),返回修正后的 JSON 与是否变更。 +func (c *CodexToolCorrector) correctToolArgumentsJSON(argsJSON, toolName string) (string, bool) { + if !gjson.Valid(argsJSON) { + return argsJSON, false + } + if !gjson.Parse(argsJSON).IsObject() { + return argsJSON, false } + updated := argsJSON corrected := false // 根据工具名称应用特定的参数修正规则 switch toolName { case "bash": // OpenCode bash 支持 workdir;有些来源会输出 work_dir。 - if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir { - if workDir, exists := argsMap["work_dir"]; exists { - argsMap["workdir"] = workDir - delete(argsMap, "work_dir") + if !gjson.Get(updated, "workdir").Exists() { + if next, changed := moveJSONField(updated, "work_dir", "workdir"); changed { + updated = next corrected = true - log.Printf("[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool") } } else { - if _, exists := argsMap["work_dir"]; exists { - delete(argsMap, "work_dir") + if next, changed := deleteJSONField(updated, "work_dir"); changed { + updated = next corrected = true - log.Printf("[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool") } } case "edit": // OpenCode edit 参数为 filePath/oldString/newString(camelCase)。 - if _, exists := argsMap["filePath"]; !exists { - if filePath, exists := argsMap["file_path"]; exists { - argsMap["filePath"] = filePath - delete(argsMap, "file_path") + if !gjson.Get(updated, "filePath").Exists() { + if next, changed := moveJSONField(updated, "file_path", "filePath"); changed { + updated = next corrected = true - log.Printf("[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool") - } else if filePath, exists := argsMap["path"]; exists { - argsMap["filePath"] = filePath - delete(argsMap, "path") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool") + } else if next, changed := moveJSONField(updated, "path", "filePath"); changed { + updated = next corrected = true - log.Printf("[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool") - } else if filePath, exists := argsMap["file"]; exists { - argsMap["filePath"] = filePath - delete(argsMap, "file") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool") + } else if next, changed := moveJSONField(updated, "file", "filePath"); changed { + updated = next corrected = true - log.Printf("[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool") } } - if _, exists := argsMap["oldString"]; !exists { - if oldString, exists := argsMap["old_string"]; exists { - argsMap["oldString"] = oldString - delete(argsMap, "old_string") - corrected = true - log.Printf("[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") - } + if next, changed := moveJSONField(updated, "old_string", "oldString"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") } - if _, exists := argsMap["newString"]; !exists { - if newString, exists := argsMap["new_string"]; exists { - argsMap["newString"] = newString - delete(argsMap, "new_string") - corrected = true - log.Printf("[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool") - } + if next, changed := moveJSONField(updated, "new_string", "newString"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool") } - if _, exists := argsMap["replaceAll"]; !exists { - if replaceAll, exists := argsMap["replace_all"]; exists { - argsMap["replaceAll"] = replaceAll - delete(argsMap, "replace_all") - corrected = true - log.Printf("[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") - } + if next, changed := moveJSONField(updated, "replace_all", "replaceAll"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") } } + return updated, corrected +} - // 如果修正了参数,需要重新序列化 - if corrected { - if _, wasString := arguments.(string); wasString { - // 原本是字符串,序列化回字符串 - if newArgsJSON, err := json.Marshal(argsMap); err == nil { - functionCall["arguments"] = string(newArgsJSON) - } - } else { - // 原本是 map,直接赋值 - functionCall["arguments"] = argsMap - } +func moveJSONField(input, from, to string) (string, bool) { + if gjson.Get(input, to).Exists() { + return input, false } + src := gjson.Get(input, from) + if !src.Exists() { + return input, false + } + next, err := sjson.SetRaw(input, to, src.Raw) + if err != nil { + return input, false + } + next, err = sjson.Delete(next, from) + if err != nil { + return input, false + } + return next, true +} - return corrected +func deleteJSONField(input, path string) (string, bool) { + if !gjson.Get(input, path).Exists() { + return input, false + } + next, err := sjson.Delete(input, path) + if err != nil { + return input, false + } + return next, true } // recordCorrection 记录一次工具名称修正 @@ -303,7 +344,7 @@ func (c *CodexToolCorrector) recordCorrection(from, to string) { key := fmt.Sprintf("%s->%s", from, to) c.stats.CorrectionsByTool[key]++ - log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)", + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)", from, to, c.stats.TotalCorrected) } diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go index ff518ea6..7c83de9e 100644 --- a/backend/internal/service/openai_tool_corrector_test.go +++ b/backend/internal/service/openai_tool_corrector_test.go @@ -5,6 +5,15 @@ import ( "testing" ) +func TestMayContainToolCallPayload(t *testing.T) { + if mayContainToolCallPayload([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)) { + t.Fatalf("plain text event should not trigger tool-call parsing") + } + if !mayContainToolCallPayload([]byte(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)) { + t.Fatalf("tool_calls event should trigger tool-call parsing") + } +} + func TestCorrectToolCallsInSSEData(t *testing.T) { corrector := NewCodexToolCorrector() diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go new file mode 100644 index 00000000..3fe08179 --- /dev/null +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -0,0 +1,190 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + account := Account{ + ID: 2, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.True(t, selection.Acquired) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + account := Account{ + ID: 8, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}}) + require.NoError(t, err) + require.Nil(t, selection) +} + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + account := Account{ + ID: 11, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_ws_force_http": true, + "responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil) + require.NoError(t, err) + require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连") +} + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + accounts := []Account{ + { + ID: 21, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + { + ID: 22, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 9, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2 + cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second + + concurrencyCache := stubConcurrencyCache{ + acquireResults: map[int64]bool{ + 21: false, // previous_response 命中的账号繁忙 + 22: true, // 次优账号可用(若回退会命中) + }, + waitCounts: map[int64]int{ + 21: 999, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(21), selection.Account.ID, "busy previous_response sticky account should remain selected") + require.False(t, selection.Acquired) + require.NotNil(t, selection.WaitPlan) + require.Equal(t, int64(21), selection.WaitPlan.AccountID) +} + +func newOpenAIWSV2TestConfig() *config.Config { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + return cfg +} diff --git a/backend/internal/service/openai_ws_client.go b/backend/internal/service/openai_ws_client.go new file mode 100644 index 00000000..9f3c47b7 --- /dev/null +++ b/backend/internal/service/openai_ws_client.go @@ -0,0 +1,285 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + coderws "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" +) + +const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024 +const ( + openAIWSProxyTransportMaxIdleConns = 128 + openAIWSProxyTransportMaxIdleConnsPerHost = 64 + openAIWSProxyTransportIdleConnTimeout = 90 * time.Second + openAIWSProxyClientCacheMaxEntries = 256 + openAIWSProxyClientCacheIdleTTL = 15 * time.Minute +) + +type OpenAIWSTransportMetricsSnapshot struct { + ProxyClientCacheHits int64 `json:"proxy_client_cache_hits"` + ProxyClientCacheMisses int64 `json:"proxy_client_cache_misses"` + TransportReuseRatio float64 `json:"transport_reuse_ratio"` +} + +// openAIWSClientConn 抽象 WS 客户端连接,便于替换底层实现。 +type openAIWSClientConn interface { + WriteJSON(ctx context.Context, value any) error + ReadMessage(ctx context.Context) ([]byte, error) + Ping(ctx context.Context) error + Close() error +} + +// openAIWSClientDialer 抽象 WS 建连器。 +type openAIWSClientDialer interface { + Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error) +} + +type openAIWSTransportMetricsDialer interface { + SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot +} + +func newDefaultOpenAIWSClientDialer() openAIWSClientDialer { + return &coderOpenAIWSClientDialer{ + proxyClients: make(map[string]*openAIWSProxyClientEntry), + } +} + +type coderOpenAIWSClientDialer struct { + proxyMu sync.Mutex + proxyClients map[string]*openAIWSProxyClientEntry + proxyHits atomic.Int64 + proxyMisses atomic.Int64 +} + +type openAIWSProxyClientEntry struct { + client *http.Client + lastUsedUnixNano int64 +} + +func (d *coderOpenAIWSClientDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + targetURL := strings.TrimSpace(wsURL) + if targetURL == "" { + return nil, 0, nil, errors.New("ws url is empty") + } + + opts := &coderws.DialOptions{ + HTTPHeader: cloneHeader(headers), + CompressionMode: coderws.CompressionContextTakeover, + } + if proxy := strings.TrimSpace(proxyURL); proxy != "" { + proxyClient, err := d.proxyHTTPClient(proxy) + if err != nil { + return nil, 0, nil, err + } + opts.HTTPClient = proxyClient + } + + conn, resp, err := coderws.Dial(ctx, targetURL, opts) + if err != nil { + status := 0 + respHeaders := http.Header(nil) + if resp != nil { + status = resp.StatusCode + respHeaders = cloneHeader(resp.Header) + } + return nil, status, respHeaders, err + } + // coder/websocket 默认单消息读取上限为 32KB,Codex WS 事件(如 rate_limits/大 delta) + // 可能超过该阈值,需显式提高上限,避免本地 read_fail(message too big)。 + conn.SetReadLimit(openAIWSMessageReadLimitBytes) + respHeaders := http.Header(nil) + if resp != nil { + respHeaders = cloneHeader(resp.Header) + } + return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil +} + +func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) { + if d == nil { + return nil, errors.New("openai ws dialer is nil") + } + normalizedProxy := strings.TrimSpace(proxy) + if normalizedProxy == "" { + return nil, errors.New("proxy url is empty") + } + parsedProxyURL, err := url.Parse(normalizedProxy) + if err != nil { + return nil, fmt.Errorf("invalid proxy url: %w", err) + } + now := time.Now().UnixNano() + + d.proxyMu.Lock() + defer d.proxyMu.Unlock() + if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != nil && entry.client != nil { + entry.lastUsedUnixNano = now + d.proxyHits.Add(1) + return entry.client, nil + } + d.cleanupProxyClientsLocked(now) + transport := &http.Transport{ + Proxy: http.ProxyURL(parsedProxyURL), + MaxIdleConns: openAIWSProxyTransportMaxIdleConns, + MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost, + IdleConnTimeout: openAIWSProxyTransportIdleConnTimeout, + TLSHandshakeTimeout: 10 * time.Second, + ForceAttemptHTTP2: true, + } + client := &http.Client{Transport: transport} + d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{ + client: client, + lastUsedUnixNano: now, + } + d.ensureProxyClientCapacityLocked() + d.proxyMisses.Add(1) + return client, nil +} + +func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) { + if d == nil || len(d.proxyClients) == 0 { + return + } + idleTTL := openAIWSProxyClientCacheIdleTTL + if idleTTL <= 0 { + return + } + now := time.Unix(0, nowUnixNano) + for key, entry := range d.proxyClients { + if entry == nil || entry.client == nil { + delete(d.proxyClients, key) + continue + } + lastUsed := time.Unix(0, entry.lastUsedUnixNano) + if now.Sub(lastUsed) > idleTTL { + closeOpenAIWSProxyClient(entry.client) + delete(d.proxyClients, key) + } + } +} + +func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() { + if d == nil { + return + } + maxEntries := openAIWSProxyClientCacheMaxEntries + if maxEntries <= 0 { + return + } + for len(d.proxyClients) > maxEntries { + var oldestKey string + var oldestLastUsed int64 + hasOldest := false + for key, entry := range d.proxyClients { + lastUsed := int64(0) + if entry != nil { + lastUsed = entry.lastUsedUnixNano + } + if !hasOldest || lastUsed < oldestLastUsed { + hasOldest = true + oldestKey = key + oldestLastUsed = lastUsed + } + } + if !hasOldest { + return + } + if entry := d.proxyClients[oldestKey]; entry != nil { + closeOpenAIWSProxyClient(entry.client) + } + delete(d.proxyClients, oldestKey) + } +} + +func closeOpenAIWSProxyClient(client *http.Client) { + if client == nil || client.Transport == nil { + return + } + if transport, ok := client.Transport.(*http.Transport); ok && transport != nil { + transport.CloseIdleConnections() + } +} + +func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot { + if d == nil { + return OpenAIWSTransportMetricsSnapshot{} + } + hits := d.proxyHits.Load() + misses := d.proxyMisses.Load() + total := hits + misses + reuseRatio := 0.0 + if total > 0 { + reuseRatio = float64(hits) / float64(total) + } + return OpenAIWSTransportMetricsSnapshot{ + ProxyClientCacheHits: hits, + ProxyClientCacheMisses: misses, + TransportReuseRatio: reuseRatio, + } +} + +type coderOpenAIWSClientConn struct { + conn *coderws.Conn +} + +func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return wsjson.Write(ctx, c.conn, value) +} + +func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) { + if c == nil || c.conn == nil { + return nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + + msgType, payload, err := c.conn.Read(ctx) + if err != nil { + return nil, err + } + switch msgType { + case coderws.MessageText, coderws.MessageBinary: + return payload, nil + default: + return nil, errOpenAIWSConnClosed + } +} + +func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Ping(ctx) +} + +func (c *coderOpenAIWSClientConn) Close() error { + if c == nil || c.conn == nil { + return nil + } + // Close 为幂等,忽略重复关闭错误。 + _ = c.conn.Close(coderws.StatusNormalClosure, "") + _ = c.conn.CloseNow() + return nil +} diff --git a/backend/internal/service/openai_ws_client_test.go b/backend/internal/service/openai_ws_client_test.go new file mode 100644 index 00000000..a88d6266 --- /dev/null +++ b/backend/internal/service/openai_ws_client_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + c1, err := impl.proxyHTTPClient("http://127.0.0.1:8080") + require.NoError(t, err) + c2, err := impl.proxyHTTPClient("http://127.0.0.1:8080") + require.NoError(t, err) + require.Same(t, c1, c2, "同一代理地址应复用同一个 HTTP 客户端") + + c3, err := impl.proxyHTTPClient("http://127.0.0.1:8081") + require.NoError(t, err) + require.NotSame(t, c1, c3, "不同代理地址应分离客户端") +} + +func TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + _, err := impl.proxyHTTPClient("://bad") + require.Error(t, err) +} + +func TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + _, err := impl.proxyHTTPClient("http://127.0.0.1:18080") + require.NoError(t, err) + _, err = impl.proxyHTTPClient("http://127.0.0.1:18080") + require.NoError(t, err) + _, err = impl.proxyHTTPClient("http://127.0.0.1:18081") + require.NoError(t, err) + + snapshot := impl.SnapshotTransportMetrics() + require.Equal(t, int64(1), snapshot.ProxyClientCacheHits) + require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses) + require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001) +} + +func TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + total := openAIWSProxyClientCacheMaxEntries + 32 + for i := 0; i < total; i++ { + _, err := impl.proxyHTTPClient(fmt.Sprintf("http://127.0.0.1:%d", 20000+i)) + require.NoError(t, err) + } + + impl.proxyMu.Lock() + cacheSize := len(impl.proxyClients) + impl.proxyMu.Unlock() + + require.LessOrEqual(t, cacheSize, openAIWSProxyClientCacheMaxEntries, "代理客户端缓存应受容量上限约束") +} + +func TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + oldProxy := "http://127.0.0.1:28080" + _, err := impl.proxyHTTPClient(oldProxy) + require.NoError(t, err) + + impl.proxyMu.Lock() + oldEntry := impl.proxyClients[oldProxy] + require.NotNil(t, oldEntry) + oldEntry.lastUsedUnixNano = time.Now().Add(-openAIWSProxyClientCacheIdleTTL - time.Minute).UnixNano() + impl.proxyMu.Unlock() + + // 触发一次新的代理获取,驱动 TTL 清理。 + _, err = impl.proxyHTTPClient("http://127.0.0.1:28081") + require.NoError(t, err) + + impl.proxyMu.Lock() + _, exists := impl.proxyClients[oldProxy] + impl.proxyMu.Unlock() + + require.False(t, exists, "超过空闲 TTL 的代理客户端应被回收") +} + +func TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + client, err := impl.proxyHTTPClient("http://127.0.0.1:38080") + require.NoError(t, err) + require.NotNil(t, client) + + transport, ok := client.Transport.(*http.Transport) + require.True(t, ok) + require.NotNil(t, transport) + require.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout) +} diff --git a/backend/internal/service/openai_ws_fallback_test.go b/backend/internal/service/openai_ws_fallback_test.go new file mode 100644 index 00000000..ce06f6a2 --- /dev/null +++ b/backend/internal/service/openai_ws_fallback_test.go @@ -0,0 +1,251 @@ +package service + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" +) + +func TestClassifyOpenAIWSAcquireError(t *testing.T) { + t.Run("dial_426_upgrade_required", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 426, Err: errors.New("upgrade required")} + require.Equal(t, "upgrade_required", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("queue_full", func(t *testing.T) { + require.Equal(t, "conn_queue_full", classifyOpenAIWSAcquireError(errOpenAIWSConnQueueFull)) + }) + + t.Run("preferred_conn_unavailable", func(t *testing.T) { + require.Equal(t, "preferred_conn_unavailable", classifyOpenAIWSAcquireError(errOpenAIWSPreferredConnUnavailable)) + }) + + t.Run("acquire_timeout", func(t *testing.T) { + require.Equal(t, "acquire_timeout", classifyOpenAIWSAcquireError(context.DeadlineExceeded)) + }) + + t.Run("auth_failed_401", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 401, Err: errors.New("unauthorized")} + require.Equal(t, "auth_failed", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("upstream_rate_limited", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 429, Err: errors.New("rate limited")} + require.Equal(t, "upstream_rate_limited", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("upstream_5xx", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 502, Err: errors.New("bad gateway")} + require.Equal(t, "upstream_5xx", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("dial_failed_other_status", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 418, Err: errors.New("teapot")} + require.Equal(t, "dial_failed", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("other", func(t *testing.T) { + require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(errors.New("x"))) + }) + + t.Run("nil", func(t *testing.T) { + require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(nil)) + }) +} + +func TestClassifyOpenAIWSDialError(t *testing.T) { + t.Run("handshake_not_finished", func(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: http.StatusBadGateway, + Err: errors.New("WebSocket protocol error: Handshake not finished"), + } + require.Equal(t, "handshake_not_finished", classifyOpenAIWSDialError(err)) + }) + + t.Run("context_deadline", func(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: 0, + Err: context.DeadlineExceeded, + } + require.Equal(t, "ctx_deadline_exceeded", classifyOpenAIWSDialError(err)) + }) +} + +func TestSummarizeOpenAIWSDialError(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: http.StatusBadGateway, + ResponseHeaders: http.Header{ + "Server": []string{"cloudflare"}, + "Via": []string{"1.1 example"}, + "Cf-Ray": []string{"abcd1234"}, + "X-Request-Id": []string{"req_123"}, + }, + Err: errors.New("WebSocket protocol error: Handshake not finished"), + } + + status, class, closeStatus, closeReason, server, via, cfRay, reqID := summarizeOpenAIWSDialError(err) + require.Equal(t, http.StatusBadGateway, status) + require.Equal(t, "handshake_not_finished", class) + require.Equal(t, "-", closeStatus) + require.Equal(t, "-", closeReason) + require.Equal(t, "cloudflare", server) + require.Equal(t, "1.1 example", via) + require.Equal(t, "abcd1234", cfRay) + require.Equal(t, "req_123", reqID) +} + +func TestClassifyOpenAIWSErrorEvent(t *testing.T) { + reason, recoverable := classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"upgrade_required","message":"Upgrade required"}}`)) + require.Equal(t, "upgrade_required", reason) + require.True(t, recoverable) + + reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`)) + require.Equal(t, "previous_response_not_found", reason) + require.True(t, recoverable) +} + +func TestClassifyOpenAIWSReconnectReason(t *testing.T) { + reason, retryable := classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("policy_violation", errors.New("policy"))) + require.Equal(t, "policy_violation", reason) + require.False(t, retryable) + + reason, retryable = classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("read_event", errors.New("io"))) + require.Equal(t, "read_event", reason) + require.True(t, retryable) +} + +func TestOpenAIWSErrorHTTPStatus(t *testing.T) { + require.Equal(t, http.StatusBadRequest, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`))) + require.Equal(t, http.StatusUnauthorized, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"authentication_error","code":"invalid_api_key","message":"auth failed"}}`))) + require.Equal(t, http.StatusForbidden, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"permission_error","code":"forbidden","message":"forbidden"}}`))) + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"rate limited"}}`))) + require.Equal(t, http.StatusBadGateway, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"server_error","code":"server_error","message":"server"}}`))) +} + +func TestResolveOpenAIWSFallbackErrorResponse(t *testing.T) { + t.Run("previous_response_not_found", func(t *testing.T) { + statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse( + wrapOpenAIWSFallback("previous_response_not_found", errors.New("previous response not found")), + ) + require.True(t, ok) + require.Equal(t, http.StatusBadRequest, statusCode) + require.Equal(t, "invalid_request_error", errType) + require.Equal(t, "previous response not found", clientMessage) + require.Equal(t, "previous response not found", upstreamMessage) + }) + + t.Run("auth_failed_uses_dial_status", func(t *testing.T) { + statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse( + wrapOpenAIWSFallback("auth_failed", &openAIWSDialError{ + StatusCode: http.StatusForbidden, + Err: errors.New("forbidden"), + }), + ) + require.True(t, ok) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, "upstream_error", errType) + require.Equal(t, "forbidden", clientMessage) + require.Equal(t, "forbidden", upstreamMessage) + }) + + t.Run("non_fallback_error_not_resolved", func(t *testing.T) { + _, _, _, _, ok := resolveOpenAIWSFallbackErrorResponse(errors.New("plain error")) + require.False(t, ok) + }) +} + +func TestOpenAIWSFallbackCooling(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + require.False(t, svc.isOpenAIWSFallbackCooling(1)) + svc.markOpenAIWSFallbackCooling(1, "upgrade_required") + require.True(t, svc.isOpenAIWSFallbackCooling(1)) + + svc.clearOpenAIWSFallbackCooling(1) + require.False(t, svc.isOpenAIWSFallbackCooling(1)) + + svc.markOpenAIWSFallbackCooling(2, "x") + time.Sleep(1200 * time.Millisecond) + require.False(t, svc.isOpenAIWSFallbackCooling(2)) +} + +func TestOpenAIWSRetryBackoff(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 100 + svc.cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 400 + svc.cfg.Gateway.OpenAIWS.RetryJitterRatio = 0 + + require.Equal(t, time.Duration(100)*time.Millisecond, svc.openAIWSRetryBackoff(1)) + require.Equal(t, time.Duration(200)*time.Millisecond, svc.openAIWSRetryBackoff(2)) + require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(3)) + require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(4)) +} + +func TestOpenAIWSRetryTotalBudget(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 1200 + require.Equal(t, 1200*time.Millisecond, svc.openAIWSRetryTotalBudget()) + + svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 0 + require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget()) +} + +func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) { + require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation})) + require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig})) + require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io"))) +} + +func TestOpenAIWSStoreDisabledConnMode(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true + require.Equal(t, openAIWSStoreDisabledConnModeStrict, svc.openAIWSStoreDisabledConnMode()) + + svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "adaptive" + require.Equal(t, openAIWSStoreDisabledConnModeAdaptive, svc.openAIWSStoreDisabledConnMode()) + + svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "" + svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false + require.Equal(t, openAIWSStoreDisabledConnModeOff, svc.openAIWSStoreDisabledConnMode()) +} + +func TestShouldForceNewConnOnStoreDisabled(t *testing.T) { + require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeStrict, "")) + require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeOff, "policy_violation")) + + require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "policy_violation")) + require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "prewarm_message_too_big")) + require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "read_event")) +} + +func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) { + svc := &OpenAIGatewayService{} + svc.recordOpenAIWSRetryAttempt(150 * time.Millisecond) + svc.recordOpenAIWSRetryAttempt(0) + svc.recordOpenAIWSRetryExhausted() + svc.recordOpenAIWSNonRetryableFastFallback() + + snapshot := svc.SnapshotOpenAIWSRetryMetrics() + require.Equal(t, int64(2), snapshot.RetryAttemptsTotal) + require.Equal(t, int64(150), snapshot.RetryBackoffMsTotal) + require.Equal(t, int64(1), snapshot.RetryExhaustedTotal) + require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal) +} + +func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + + svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 0 + require.True(t, svc.shouldLogOpenAIWSPayloadSchema(1), "首次尝试应始终记录 payload_schema") + require.False(t, svc.shouldLogOpenAIWSPayloadSchema(2)) + + svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 1 + require.True(t, svc.shouldLogOpenAIWSPayloadSchema(2)) +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go new file mode 100644 index 00000000..74ba472f --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder.go @@ -0,0 +1,3955 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "net/url" + "sort" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +const ( + openAIWSBetaV1Value = "responses_websockets=2026-02-04" + openAIWSBetaV2Value = "responses_websockets=2026-02-06" + + openAIWSTurnStateHeader = "x-codex-turn-state" + openAIWSTurnMetadataHeader = "x-codex-turn-metadata" + + openAIWSLogValueMaxLen = 160 + openAIWSHeaderValueMaxLen = 120 + openAIWSIDValueMaxLen = 64 + openAIWSEventLogHeadLimit = 20 + openAIWSEventLogEveryN = 50 + openAIWSBufferLogHeadLimit = 8 + openAIWSBufferLogEveryN = 20 + openAIWSPrewarmEventLogHead = 10 + openAIWSPayloadKeySizeTopN = 6 + + openAIWSPayloadSizeEstimateDepth = 3 + openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024 + openAIWSPayloadSizeEstimateMaxItems = 16 + + openAIWSEventFlushBatchSizeDefault = 4 + openAIWSEventFlushIntervalDefault = 25 * time.Millisecond + openAIWSPayloadLogSampleDefault = 0.2 + + openAIWSStoreDisabledConnModeStrict = "strict" + openAIWSStoreDisabledConnModeAdaptive = "adaptive" + openAIWSStoreDisabledConnModeOff = "off" + + openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found" + openAIWSMaxPrevResponseIDDeletePasses = 8 +) + +var openAIWSLogValueReplacer = strings.NewReplacer( + "error", "err", + "fallback", "fb", + "warning", "warnx", + "failed", "fail", +) + +var openAIWSIngressPreflightPingIdle = 20 * time.Second + +// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。 +type openAIWSFallbackError struct { + Reason string + Err error +} + +func (e *openAIWSFallbackError) Error() string { + if e == nil { + return "" + } + if e.Err == nil { + return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason)) + } + return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err) +} + +func (e *openAIWSFallbackError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +func wrapOpenAIWSFallback(reason string, err error) error { + return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err} +} + +// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。 +type OpenAIWSClientCloseError struct { + statusCode coderws.StatusCode + reason string + err error +} + +type openAIWSIngressTurnError struct { + stage string + cause error + wroteDownstream bool +} + +func (e *openAIWSIngressTurnError) Error() string { + if e == nil { + return "" + } + if e.cause == nil { + return strings.TrimSpace(e.stage) + } + return e.cause.Error() +} + +func (e *openAIWSIngressTurnError) Unwrap() error { + if e == nil { + return nil + } + return e.cause +} + +func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error { + if cause == nil { + return nil + } + return &openAIWSIngressTurnError{ + stage: strings.TrimSpace(stage), + cause: cause, + wroteDownstream: wroteDownstream, + } +} + +func isOpenAIWSIngressTurnRetryable(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) { + return false + } + if turnErr.wroteDownstream { + return false + } + switch turnErr.stage { + case "write_upstream", "read_upstream": + return true + default: + return false + } +} + +func openAIWSIngressTurnRetryReason(err error) string { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return "unknown" + } + if turnErr.stage == "" { + return "unknown" + } + return turnErr.stage +} + +func isOpenAIWSIngressPreviousResponseNotFound(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound { + return false + } + return !turnErr.wroteDownstream +} + +// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。 +func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error { + return &OpenAIWSClientCloseError{ + statusCode: statusCode, + reason: strings.TrimSpace(reason), + err: err, + } +} + +func (e *OpenAIWSClientCloseError) Error() string { + if e == nil { + return "" + } + if e.err == nil { + return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason)) + } + return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err) +} + +func (e *OpenAIWSClientCloseError) Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode { + if e == nil { + return coderws.StatusInternalError + } + return e.statusCode +} + +func (e *OpenAIWSClientCloseError) Reason() string { + if e == nil { + return "" + } + return strings.TrimSpace(e.reason) +} + +// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。 +type OpenAIWSIngressHooks struct { + BeforeTurn func(turn int) error + AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) +} + +func normalizeOpenAIWSLogValue(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "-" + } + return openAIWSLogValueReplacer.Replace(trimmed) +} + +func truncateOpenAIWSLogValue(value string, maxLen int) string { + normalized := normalizeOpenAIWSLogValue(value) + if normalized == "-" || maxLen <= 0 { + return normalized + } + if len(normalized) <= maxLen { + return normalized + } + return normalized[:maxLen] + "..." +} + +func openAIWSHeaderValueForLog(headers http.Header, key string) string { + if headers == nil { + return "-" + } + return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen) +} + +func hasOpenAIWSHeader(headers http.Header, key string) bool { + if headers == nil { + return false + } + return strings.TrimSpace(headers.Get(key)) != "" +} + +type openAIWSSessionHeaderResolution struct { + SessionID string + ConversationID string + SessionSource string + ConversationSource string +} + +func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution { + resolution := openAIWSSessionHeaderResolution{ + SessionSource: "none", + ConversationSource: "none", + } + if c != nil && c.Request != nil { + if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" { + resolution.SessionID = sessionID + resolution.SessionSource = "header_session_id" + } + if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" { + resolution.ConversationID = conversationID + resolution.ConversationSource = "header_conversation_id" + if resolution.SessionID == "" { + resolution.SessionID = conversationID + resolution.SessionSource = "header_conversation_id" + } + } + } + + cacheKey := strings.TrimSpace(promptCacheKey) + if cacheKey != "" { + if resolution.SessionID == "" { + resolution.SessionID = cacheKey + resolution.SessionSource = "prompt_cache_key" + } + } + return resolution +} + +func shouldLogOpenAIWSEvent(idx int, eventType string) bool { + if idx <= openAIWSEventLogHeadLimit { + return true + } + if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 { + return true + } + if eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + return true + } + return false +} + +func shouldLogOpenAIWSBufferedEvent(idx int) bool { + if idx <= openAIWSBufferLogHeadLimit { + return true + } + if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 { + return true + } + return false +} + +func openAIWSEventMayContainModel(eventType string) bool { + switch eventType { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + trimmed := strings.TrimSpace(eventType) + if trimmed == eventType { + return false + } + switch trimmed { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + return false + } + } +} + +func openAIWSEventMayContainToolCalls(eventType string) bool { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return false + } + if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") { + return true + } + switch eventType { + case "response.output_item.added", "response.output_item.done", "response.completed", "response.done": + return true + default: + return false + } +} + +func openAIWSEventShouldParseUsage(eventType string) bool { + return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed" +} + +func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) { + if len(message) == 0 { + return "", "", gjson.Result{} + } + values := gjson.GetManyBytes(message, "type", "response.id", "id", "response") + eventType = strings.TrimSpace(values[0].String()) + if id := strings.TrimSpace(values[1].String()); id != "" { + responseID = id + } else { + responseID = strings.TrimSpace(values[2].String()) + } + return eventType, responseID, values[3] +} + +func openAIWSMessageLikelyContainsToolCalls(message []byte) bool { + if len(message) == 0 { + return false + } + return bytes.Contains(message, []byte(`"tool_calls"`)) || + bytes.Contains(message, []byte(`"tool_call"`)) || + bytes.Contains(message, []byte(`"function_call"`)) +} + +func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) { + if usage == nil || len(message) == 0 { + return + } + values := gjson.GetManyBytes( + message, + "response.usage.input_tokens", + "response.usage.output_tokens", + "response.usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} + +func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "", "", "" + } + values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message") + return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String()) +} + +func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) { + code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen) + errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen) + errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen) + return code, errType, errMessage +} + +func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "-", "-", "-" + } + return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string { + if len(payload) == 0 { + return "-" + } + type keySize struct { + Key string + Size int + } + sizes := make([]keySize, 0, len(payload)) + for key, value := range payload { + size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth) + sizes = append(sizes, keySize{Key: key, Size: size}) + } + sort.Slice(sizes, func(i, j int) bool { + if sizes[i].Size == sizes[j].Size { + return sizes[i].Key < sizes[j].Key + } + return sizes[i].Size > sizes[j].Size + }) + + if topN <= 0 || topN > len(sizes) { + topN = len(sizes) + } + parts := make([]string, 0, topN) + for idx := 0; idx < topN; idx++ { + item := sizes[idx] + parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size)) + } + return strings.Join(parts, ",") +} + +func estimateOpenAIWSPayloadValueSize(value any, depth int) int { + if depth <= 0 { + return -1 + } + switch v := value.(type) { + case nil: + return 0 + case string: + return len(v) + case []byte: + return len(v) + case bool: + return 1 + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return 8 + case float32, float64: + return 8 + case map[string]any: + if len(v) == 0 { + return 2 + } + total := 2 + count := 0 + for key, item := range v { + count++ + if count > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1) + if itemSize < 0 { + return -1 + } + total += len(key) + itemSize + 3 + if total > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + case []any: + if len(v) == 0 { + return 2 + } + total := 2 + limit := len(v) + if limit > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + for i := 0; i < limit; i++ { + itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1) + if itemSize < 0 { + return -1 + } + total += itemSize + 1 + if total > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + default: + raw, err := json.Marshal(v) + if err != nil { + return -1 + } + if len(raw) > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + return len(raw) + } +} + +func openAIWSPayloadString(payload map[string]any, key string) string { + if len(payload) == 0 { + return "" + } + raw, ok := payload[key] + if !ok { + return "" + } + switch v := raw.(type) { + case nil: + return "" + case string: + return strings.TrimSpace(v) + case []byte: + return strings.TrimSpace(string(v)) + default: + return "" + } +} + +func openAIWSPayloadStringFromRaw(payload []byte, key string) string { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return "" + } + return strings.TrimSpace(gjson.GetBytes(payload, key).String()) +} + +func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return defaultValue + } + value := gjson.GetBytes(payload, key) + if !value.Exists() { + return defaultValue + } + if value.Type != gjson.True && value.Type != gjson.False { + return defaultValue + } + return value.Bool() +} + +func openAIWSSessionHashesFromID(sessionID string) (string, string) { + return deriveOpenAISessionHashes(sessionID) +} + +func extractOpenAIWSImageURL(value any) string { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) + case map[string]any: + if raw, ok := v["url"].(string); ok { + return strings.TrimSpace(raw) + } + } + return "" +} + +func summarizeOpenAIWSInput(input any) string { + items, ok := input.([]any) + if !ok || len(items) == 0 { + return "-" + } + + itemCount := len(items) + textChars := 0 + imageDataURLs := 0 + imageDataURLChars := 0 + imageRemoteURLs := 0 + + handleContentItem := func(contentItem map[string]any) { + contentType, _ := contentItem["type"].(string) + switch strings.TrimSpace(contentType) { + case "input_text", "output_text", "text": + if text, ok := contentItem["text"].(string); ok { + textChars += len(text) + } + case "input_image": + imageURL := extractOpenAIWSImageURL(contentItem["image_url"]) + if imageURL == "" { + return + } + if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { + imageDataURLs++ + imageDataURLChars += len(imageURL) + return + } + imageRemoteURLs++ + } + } + + handleInputItem := func(inputItem map[string]any) { + if content, ok := inputItem["content"].([]any); ok { + for _, rawContent := range content { + contentItem, ok := rawContent.(map[string]any) + if !ok { + continue + } + handleContentItem(contentItem) + } + return + } + + itemType, _ := inputItem["type"].(string) + switch strings.TrimSpace(itemType) { + case "input_text", "output_text", "text": + if text, ok := inputItem["text"].(string); ok { + textChars += len(text) + } + case "input_image": + imageURL := extractOpenAIWSImageURL(inputItem["image_url"]) + if imageURL == "" { + return + } + if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { + imageDataURLs++ + imageDataURLChars += len(imageURL) + return + } + imageRemoteURLs++ + } + } + + for _, rawItem := range items { + inputItem, ok := rawItem.(map[string]any) + if !ok { + continue + } + handleInputItem(inputItem) + } + + return fmt.Sprintf( + "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d", + itemCount, + textChars, + imageDataURLs, + imageDataURLChars, + imageRemoteURLs, + ) +} + +func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return + } + if _, exists := payload[key]; !exists { + return + } + delete(payload, key) + *removed = append(*removed, key) +} + +// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段, +// 避免重试成功却改变原始请求语义。 +// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。 +func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) { + if len(payload) == 0 { + return "empty", nil + } + if attempt <= 1 { + return "full", nil + } + + removed := make([]string, 0, 2) + if attempt >= 2 { + dropOpenAIWSPayloadKey(payload, "include", &removed) + } + + if len(removed) == 0 { + return "full", nil + } + sort.Strings(removed) + return "trim_optional_fields", removed +} + +func logOpenAIWSModeInfo(format string, args ...any) { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...) +} + +func isOpenAIWSModeDebugEnabled() bool { + return logger.L().Core().Enabled(zap.DebugLevel) +} + +func logOpenAIWSModeDebug(format string, args ...any) { + if !isOpenAIWSModeDebugEnabled() { + return + } + logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...) +} + +func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) { + if err == nil { + return + } + logger.L().Warn( + "openai.ws_bind_response_account_failed", + zap.Int64("group_id", groupID), + zap.Int64("account_id", accountID), + zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)), + zap.Error(err), + ) +} + +func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) { + if err == nil { + return "-", "-" + } + statusCode := coderws.CloseStatus(err) + if statusCode == -1 { + return "-", "-" + } + closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) + closeReason := "-" + var closeErr coderws.CloseError + if errors.As(err, &closeErr) { + reasonText := strings.TrimSpace(closeErr.Reason) + if reasonText != "" { + closeReason = normalizeOpenAIWSLogValue(reasonText) + } + } + return normalizeOpenAIWSLogValue(closeStatus), closeReason +} + +func unwrapOpenAIWSDialBaseError(err error) error { + if err == nil { + return nil + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil { + return dialErr.Err + } + return err +} + +func openAIWSDialRespHeaderForLog(err error, key string) string { + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil { + return "-" + } + return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen) +} + +func classifyOpenAIWSDialError(err error) string { + if err == nil { + return "-" + } + baseErr := unwrapOpenAIWSDialBaseError(err) + if baseErr == nil { + return "-" + } + if errors.Is(baseErr, context.DeadlineExceeded) { + return "ctx_deadline_exceeded" + } + if errors.Is(baseErr, context.Canceled) { + return "ctx_canceled" + } + var netErr net.Error + if errors.As(baseErr, &netErr) && netErr.Timeout() { + return "net_timeout" + } + if status := coderws.CloseStatus(baseErr); status != -1 { + return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status))) + } + message := strings.ToLower(strings.TrimSpace(baseErr.Error())) + switch { + case strings.Contains(message, "handshake not finished"): + return "handshake_not_finished" + case strings.Contains(message, "bad handshake"): + return "bad_handshake" + case strings.Contains(message, "connection refused"): + return "connection_refused" + case strings.Contains(message, "no such host"): + return "dns_not_found" + case strings.Contains(message, "tls"): + return "tls_error" + case strings.Contains(message, "i/o timeout"): + return "io_timeout" + case strings.Contains(message, "context deadline exceeded"): + return "ctx_deadline_exceeded" + default: + return "dial_error" + } +} + +func summarizeOpenAIWSDialError(err error) ( + statusCode int, + dialClass string, + closeStatus string, + closeReason string, + respServer string, + respVia string, + respCFRay string, + respRequestID string, +) { + dialClass = "-" + closeStatus = "-" + closeReason = "-" + respServer = "-" + respVia = "-" + respCFRay = "-" + respRequestID = "-" + if err == nil { + return + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil { + statusCode = dialErr.StatusCode + respServer = openAIWSDialRespHeaderForLog(err, "server") + respVia = openAIWSDialRespHeaderForLog(err, "via") + respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray") + respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id") + } + dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err)) + closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err)) + return +} + +func isOpenAIWSClientDisconnectError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + return true + } + switch coderws.CloseStatus(err) { + case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: + return true + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + if message == "" { + return false + } + return strings.Contains(message, "failed to read frame header: eof") || + strings.Contains(message, "unexpected eof") || + strings.Contains(message, "use of closed network connection") || + strings.Contains(message, "connection reset by peer") || + strings.Contains(message, "broken pipe") +} + +func classifyOpenAIWSReadFallbackReason(err error) string { + if err == nil { + return "read_event" + } + switch coderws.CloseStatus(err) { + case coderws.StatusPolicyViolation: + return "policy_violation" + case coderws.StatusMessageTooBig: + return "message_too_big" + default: + return "read_event" + } +} + +func sortedKeys(m map[string]any) []string { + if len(m) == 0 { + return nil + } + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} + +func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool { + if s == nil { + return nil + } + s.openaiWSPoolOnce.Do(func() { + if s.openaiWSPool == nil { + s.openaiWSPool = newOpenAIWSConnPool(s.cfg) + } + }) + return s.openaiWSPool +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot { + pool := s.getOpenAIWSConnPool() + if pool == nil { + return OpenAIWSPoolMetricsSnapshot{} + } + return pool.SnapshotMetrics() +} + +type OpenAIWSPerformanceMetricsSnapshot struct { + Pool OpenAIWSPoolMetricsSnapshot `json:"pool"` + Retry OpenAIWSRetryMetricsSnapshot `json:"retry"` + Transport OpenAIWSTransportMetricsSnapshot `json:"transport"` +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot { + pool := s.getOpenAIWSConnPool() + snapshot := OpenAIWSPerformanceMetricsSnapshot{ + Retry: s.SnapshotOpenAIWSRetryMetrics(), + } + if pool == nil { + return snapshot + } + snapshot.Pool = pool.SnapshotMetrics() + snapshot.Transport = pool.SnapshotTransportMetrics() + return snapshot +} + +func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore { + if s == nil { + return nil + } + s.openaiWSStateStoreOnce.Do(func() { + if s.openaiWSStateStore == nil { + s.openaiWSStateStore = NewOpenAIWSStateStore(s.cache) + } + }) + return s.openaiWSStateStore +} + +func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration { + if s != nil && s.cfg != nil { + seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds + if seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + return time.Hour +} + +func (s *OpenAIGatewayService) openAIWSIngressPreviousResponseRecoveryEnabled() bool { + if s != nil && s.cfg != nil { + return s.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled + } + return true +} + +func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds) * time.Second + } + return 15 * time.Minute +} + +func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second + } + return 2 * time.Minute +} + +func (s *OpenAIGatewayService) openAIWSEventFlushBatchSize() int { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushBatchSize > 0 { + return s.cfg.Gateway.OpenAIWS.EventFlushBatchSize + } + return openAIWSEventFlushBatchSizeDefault +} + +func (s *OpenAIGatewayService) openAIWSEventFlushInterval() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS >= 0 { + if s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS == 0 { + return 0 + } + return time.Duration(s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS) * time.Millisecond + } + return openAIWSEventFlushIntervalDefault +} + +func (s *OpenAIGatewayService) openAIWSPayloadLogSampleRate() float64 { + if s != nil && s.cfg != nil { + rate := s.cfg.Gateway.OpenAIWS.PayloadLogSampleRate + if rate < 0 { + return 0 + } + if rate > 1 { + return 1 + } + return rate + } + return openAIWSPayloadLogSampleDefault +} + +func (s *OpenAIGatewayService) shouldLogOpenAIWSPayloadSchema(attempt int) bool { + // 首次尝试保留一条完整 payload_schema 便于排障。 + if attempt <= 1 { + return true + } + rate := s.openAIWSPayloadLogSampleRate() + if rate <= 0 { + return false + } + if rate >= 1 { + return true + } + return rand.Float64() < rate +} + +func (s *OpenAIGatewayService) shouldEmitOpenAIWSPayloadSchema(attempt int) bool { + if !s.shouldLogOpenAIWSPayloadSchema(attempt) { + return false + } + return logger.L().Core().Enabled(zap.DebugLevel) +} + +func (s *OpenAIGatewayService) openAIWSDialTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second + } + return 10 * time.Second +} + +func (s *OpenAIGatewayService) openAIWSAcquireTimeout() time.Duration { + // Acquire 覆盖“连接复用命中/排队/新建连接”三个阶段。 + // 这里不再叠加 write_timeout,避免高并发排队时把 TTFT 长尾拉到分钟级。 + dial := s.openAIWSDialTimeout() + if dial <= 0 { + dial = 10 * time.Second + } + return dial + 2*time.Second +} + +func (s *OpenAIGatewayService) buildOpenAIResponsesWSURL(account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + var targetURL string + switch account.Type { + case AccountTypeOAuth: + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + targetURL = openaiPlatformAPIURL + } else { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return "", err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + default: + targetURL = openaiPlatformAPIURL + } + + parsed, err := url.Parse(strings.TrimSpace(targetURL)) + if err != nil { + return "", fmt.Errorf("invalid target url: %w", err) + } + switch strings.ToLower(parsed.Scheme) { + case "https": + parsed.Scheme = "wss" + case "http": + parsed.Scheme = "ws" + case "wss", "ws": + // 保持不变 + default: + return "", fmt.Errorf("unsupported scheme for ws: %s", parsed.Scheme) + } + return parsed.String(), nil +} + +func (s *OpenAIGatewayService) buildOpenAIWSHeaders( + c *gin.Context, + account *Account, + token string, + decision OpenAIWSProtocolDecision, + isCodexCLI bool, + turnState string, + turnMetadata string, + promptCacheKey string, +) (http.Header, openAIWSSessionHeaderResolution) { + headers := make(http.Header) + headers.Set("authorization", "Bearer "+token) + + sessionResolution := resolveOpenAIWSSessionHeaders(c, promptCacheKey) + if c != nil && c.Request != nil { + if v := strings.TrimSpace(c.Request.Header.Get("accept-language")); v != "" { + headers.Set("accept-language", v) + } + } + if sessionResolution.SessionID != "" { + headers.Set("session_id", sessionResolution.SessionID) + } + if sessionResolution.ConversationID != "" { + headers.Set("conversation_id", sessionResolution.ConversationID) + } + if state := strings.TrimSpace(turnState); state != "" { + headers.Set(openAIWSTurnStateHeader, state) + } + if metadata := strings.TrimSpace(turnMetadata); metadata != "" { + headers.Set(openAIWSTurnMetadataHeader, metadata) + } + + if account != nil && account.Type == AccountTypeOAuth { + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + headers.Set("chatgpt-account-id", chatgptAccountID) + } + if isCodexCLI { + headers.Set("originator", "codex_cli_rs") + } else { + headers.Set("originator", "opencode") + } + } + + betaValue := openAIWSBetaV2Value + if decision.Transport == OpenAIUpstreamTransportResponsesWebsocket { + betaValue = openAIWSBetaV1Value + } + headers.Set("OpenAI-Beta", betaValue) + + customUA := "" + if account != nil { + customUA = account.GetOpenAIUserAgent() + } + if strings.TrimSpace(customUA) != "" { + headers.Set("user-agent", customUA) + } else if c != nil { + if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" { + headers.Set("user-agent", ua) + } + } + if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + headers.Set("user-agent", codexCLIUserAgent) + } + if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) { + headers.Set("user-agent", codexCLIUserAgent) + } + + return headers, sessionResolution +} + +func (s *OpenAIGatewayService) buildOpenAIWSCreatePayload(reqBody map[string]any, account *Account) map[string]any { + // OpenAI WS Mode 协议:response.create 字段与 HTTP /responses 基本一致。 + // 保留 stream 字段(与 Codex CLI 一致),仅移除 background。 + payload := make(map[string]any, len(reqBody)+1) + for k, v := range reqBody { + payload[k] = v + } + + delete(payload, "background") + if _, exists := payload["stream"]; !exists { + payload["stream"] = true + } + payload["type"] = "response.create" + + // OAuth 默认保持 store=false,避免误依赖服务端历史。 + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + payload["store"] = false + } + return payload +} + +func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) { + if len(payload) == 0 { + return + } + metadata := strings.TrimSpace(turnMetadata) + if metadata == "" { + return + } + + switch existing := payload["client_metadata"].(type) { + case map[string]any: + existing[openAIWSTurnMetadataHeader] = metadata + payload["client_metadata"] = existing + case map[string]string: + next := make(map[string]any, len(existing)+1) + for k, v := range existing { + next[k] = v + } + next[openAIWSTurnMetadataHeader] = metadata + payload["client_metadata"] = next + default: + payload["client_metadata"] = map[string]any{ + openAIWSTurnMetadataHeader: metadata, + } + } +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool { + if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() { + return true + } + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery { + return true + } + return false +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool { + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + return true + } + if len(reqBody) == 0 { + return false + } + rawStore, ok := reqBody["store"] + if !ok { + return false + } + storeEnabled, ok := rawStore.(bool) + if !ok { + return false + } + return !storeEnabled +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool { + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + return true + } + if len(reqBody) == 0 { + return false + } + storeValue := gjson.GetBytes(reqBody, "store") + if !storeValue.Exists() { + return false + } + if storeValue.Type != gjson.True && storeValue.Type != gjson.False { + return false + } + return !storeValue.Bool() +} + +func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string { + if s == nil || s.cfg == nil { + return openAIWSStoreDisabledConnModeStrict + } + mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode)) + switch mode { + case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff: + return mode + case "": + // 兼容旧配置:仅配置了布尔开关时按旧语义推导。 + if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { + return openAIWSStoreDisabledConnModeStrict + } + return openAIWSStoreDisabledConnModeOff + default: + return openAIWSStoreDisabledConnModeStrict + } +} + +func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool { + switch mode { + case openAIWSStoreDisabledConnModeOff: + return false + case openAIWSStoreDisabledConnModeAdaptive: + reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_") + switch reason { + case "policy_violation", "message_too_big", "auth_failed", "write_request", "write": + return true + default: + return false + } + default: + return true + } +} + +func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) { + return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes) +} + +func dropPreviousResponseIDFromRawPayloadWithDeleteFn( + payload []byte, + deleteFn func([]byte, string) ([]byte, error), +) ([]byte, bool, error) { + if len(payload) == 0 { + return payload, false, nil + } + if !gjson.GetBytes(payload, "previous_response_id").Exists() { + return payload, false, nil + } + if deleteFn == nil { + deleteFn = sjson.DeleteBytes + } + + updated := payload + for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses && + gjson.GetBytes(updated, "previous_response_id").Exists(); i++ { + next, err := deleteFn(updated, "previous_response_id") + if err != nil { + return payload, false, err + } + updated = next + } + return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil +} + +func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) { + normalizedPrevID := strings.TrimSpace(previousResponseID) + if len(payload) == 0 || normalizedPrevID == "" { + return payload, nil + } + updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID) + if err == nil { + return updated, nil + } + + var reqBody map[string]any + if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil { + return nil, err + } + reqBody["previous_response_id"] = normalizedPrevID + rebuilt, marshalErr := json.Marshal(reqBody) + if marshalErr != nil { + return nil, marshalErr + } + return rebuilt, nil +} + +func shouldInferIngressFunctionCallOutputPreviousResponseID( + storeDisabled bool, + turn int, + hasFunctionCallOutput bool, + currentPreviousResponseID string, + expectedPreviousResponseID string, +) bool { + if !storeDisabled || turn <= 1 || !hasFunctionCallOutput { + return false + } + if strings.TrimSpace(currentPreviousResponseID) != "" { + return false + } + return strings.TrimSpace(expectedPreviousResponseID) != "" +} + +func alignStoreDisabledPreviousResponseID( + payload []byte, + expectedPreviousResponseID string, +) ([]byte, bool, error) { + if len(payload) == 0 { + return payload, false, nil + } + expected := strings.TrimSpace(expectedPreviousResponseID) + if expected == "" { + return payload, false, nil + } + current := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + if current == "" || current == expected { + return payload, false, nil + } + + withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) + if dropErr != nil { + return payload, false, dropErr + } + if !removed { + return payload, false, nil + } + updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected) + if setErr != nil { + return payload, false, setErr + } + return updated, true, nil +} + +func cloneOpenAIWSPayloadBytes(payload []byte) []byte { + if len(payload) == 0 { + return nil + } + cloned := make([]byte, len(payload)) + copy(cloned, payload) + return cloned +} + +func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage { + if items == nil { + return nil + } + cloned := make([]json.RawMessage, 0, len(items)) + for idx := range items { + cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx]))) + } + return cloned +} + +func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return nil, errors.New("json is empty") + } + var decoded any + if err := json.Unmarshal(trimmed, &decoded); err != nil { + return nil, err + } + return json.Marshal(decoded) +} + +func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte { + normalized, err := normalizeOpenAIWSJSONForCompare(raw) + if err != nil { + return bytes.TrimSpace(raw) + } + return normalized +} + +func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) { + if len(payload) == 0 { + return nil, errors.New("payload is empty") + } + var decoded map[string]any + if err := json.Unmarshal(payload, &decoded); err != nil { + return nil, err + } + delete(decoded, "input") + delete(decoded, "previous_response_id") + return json.Marshal(decoded) +} + +func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) { + if len(payload) == 0 { + return nil, false, nil + } + inputValue := gjson.GetBytes(payload, "input") + if !inputValue.Exists() { + return nil, false, nil + } + if inputValue.Type == gjson.JSON { + raw := strings.TrimSpace(inputValue.Raw) + if strings.HasPrefix(raw, "[") { + var items []json.RawMessage + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil, true, err + } + return items, true, nil + } + return []json.RawMessage{json.RawMessage(raw)}, true, nil + } + if inputValue.Type == gjson.String { + encoded, _ := json.Marshal(inputValue.String()) + return []json.RawMessage{encoded}, true, nil + } + return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil +} + +func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) { + previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload) + if prevErr != nil { + return false, prevErr + } + currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) + if currentErr != nil { + return false, currentErr + } + if !previousExists && !currentExists { + return true, nil + } + if !previousExists { + return len(currentItems) == 0, nil + } + if !currentExists { + return len(previousItems) == 0, nil + } + if len(currentItems) < len(previousItems) { + return false, nil + } + + for idx := range previousItems { + previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx]) + currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx]) + if !bytes.Equal(previousNormalized, currentNormalized) { + return false, nil + } + } + return true, nil +} + +func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool { + if len(prefix) == 0 { + return true + } + if len(items) < len(prefix) { + return false + } + for idx := range prefix { + previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx]) + currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx]) + if !bytes.Equal(previousNormalized, currentNormalized) { + return false + } + } + return true +} + +func buildOpenAIWSReplayInputSequence( + previousFullInput []json.RawMessage, + previousFullInputExists bool, + currentPayload []byte, + hasPreviousResponseID bool, +) ([]json.RawMessage, bool, error) { + currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) + if currentErr != nil { + return nil, false, currentErr + } + if !hasPreviousResponseID { + return cloneOpenAIWSRawMessages(currentItems), currentExists, nil + } + if !previousFullInputExists { + return cloneOpenAIWSRawMessages(currentItems), currentExists, nil + } + if !currentExists || len(currentItems) == 0 { + return cloneOpenAIWSRawMessages(previousFullInput), true, nil + } + if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) { + return cloneOpenAIWSRawMessages(currentItems), true, nil + } + merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems)) + merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...) + merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...) + return merged, true, nil +} + +func setOpenAIWSPayloadInputSequence( + payload []byte, + fullInput []json.RawMessage, + fullInputExists bool, +) ([]byte, error) { + if !fullInputExists { + return payload, nil + } + // Preserve [] vs null semantics when input exists but is empty. + inputForMarshal := fullInput + if inputForMarshal == nil { + inputForMarshal = []json.RawMessage{} + } + inputRaw, marshalErr := json.Marshal(inputForMarshal) + if marshalErr != nil { + return nil, marshalErr + } + return sjson.SetRawBytes(payload, "input", inputRaw) +} + +func shouldKeepIngressPreviousResponseID( + previousPayload []byte, + currentPayload []byte, + lastTurnResponseID string, + hasFunctionCallOutput bool, +) (bool, string, error) { + if hasFunctionCallOutput { + return true, "has_function_call_output", nil + } + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) + if currentPreviousResponseID == "" { + return false, "missing_previous_response_id", nil + } + expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) + if expectedPreviousResponseID == "" { + return false, "missing_last_turn_response_id", nil + } + if currentPreviousResponseID != expectedPreviousResponseID { + return false, "previous_response_id_mismatch", nil + } + if len(previousPayload) == 0 { + return false, "missing_previous_turn_payload", nil + } + + previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload) + if previousComparableErr != nil { + return false, "non_input_compare_error", previousComparableErr + } + currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) + if currentComparableErr != nil { + return false, "non_input_compare_error", currentComparableErr + } + if !bytes.Equal(previousComparable, currentComparable) { + return false, "non_input_changed", nil + } + return true, "strict_incremental_ok", nil +} + +type openAIWSIngressPreviousTurnStrictState struct { + nonInputComparable []byte +} + +func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) { + if len(payload) == 0 { + return nil, nil + } + nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload) + if nonInputErr != nil { + return nil, nonInputErr + } + return &openAIWSIngressPreviousTurnStrictState{ + nonInputComparable: nonInputComparable, + }, nil +} + +func shouldKeepIngressPreviousResponseIDWithStrictState( + previousState *openAIWSIngressPreviousTurnStrictState, + currentPayload []byte, + lastTurnResponseID string, + hasFunctionCallOutput bool, +) (bool, string, error) { + if hasFunctionCallOutput { + return true, "has_function_call_output", nil + } + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) + if currentPreviousResponseID == "" { + return false, "missing_previous_response_id", nil + } + expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) + if expectedPreviousResponseID == "" { + return false, "missing_last_turn_response_id", nil + } + if currentPreviousResponseID != expectedPreviousResponseID { + return false, "previous_response_id_mismatch", nil + } + if previousState == nil { + return false, "missing_previous_turn_payload", nil + } + + currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) + if currentComparableErr != nil { + return false, "non_input_compare_error", currentComparableErr + } + if !bytes.Equal(previousState.nonInputComparable, currentComparable) { + return false, "non_input_changed", nil + } + return true, "strict_incremental_ok", nil +} + +func (s *OpenAIGatewayService) forwardOpenAIWSV2( + ctx context.Context, + c *gin.Context, + account *Account, + reqBody map[string]any, + token string, + decision OpenAIWSProtocolDecision, + isCodexCLI bool, + reqStream bool, + originalModel string, + mappedModel string, + startTime time.Time, + attempt int, + lastFailureReason string, +) (*OpenAIForwardResult, error) { + if s == nil || account == nil { + return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil")) + } + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return nil, wrapOpenAIWSFallback("build_ws_url", err) + } + wsHost := "-" + wsPath := "-" + if parsed, parseErr := url.Parse(wsURL); parseErr == nil && parsed != nil { + if h := strings.TrimSpace(parsed.Host); h != "" { + wsHost = normalizeOpenAIWSLogValue(h) + } + if p := strings.TrimSpace(parsed.Path); p != "" { + wsPath = normalizeOpenAIWSLogValue(p) + } + } + logOpenAIWSModeDebug( + "dial_target account_id=%d account_type=%s ws_host=%s ws_path=%s", + account.ID, + account.Type, + wsHost, + wsPath, + ) + + payload := s.buildOpenAIWSCreatePayload(reqBody, account) + payloadStrategy, removedKeys := applyOpenAIWSRetryPayloadStrategy(payload, attempt) + previousResponseID := openAIWSPayloadString(payload, "previous_response_id") + previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + promptCacheKey := openAIWSPayloadString(payload, "prompt_cache_key") + _, hasTools := payload["tools"] + debugEnabled := isOpenAIWSModeDebugEnabled() + payloadBytes := -1 + resolvePayloadBytes := func() int { + if payloadBytes >= 0 { + return payloadBytes + } + payloadBytes = len(payloadAsJSONBytes(payload)) + return payloadBytes + } + streamValue := "-" + if raw, ok := payload["stream"]; ok { + streamValue = normalizeOpenAIWSLogValue(strings.TrimSpace(fmt.Sprintf("%v", raw))) + } + turnState := "" + turnMetadata := "" + if c != nil && c.Request != nil { + turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) + turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)) + } + setOpenAIWSTurnMetadata(payload, turnMetadata) + payloadEventType := openAIWSPayloadString(payload, "type") + if payloadEventType == "" { + payloadEventType = "response.create" + } + if s.shouldEmitOpenAIWSPayloadSchema(attempt) { + logOpenAIWSModeInfo( + "[debug] payload_schema account_id=%d attempt=%d event=%s payload_keys=%s payload_bytes=%d payload_key_sizes=%s input_summary=%s stream=%s payload_strategy=%s removed_keys=%s has_previous_response_id=%v has_prompt_cache_key=%v has_tools=%v", + account.ID, + attempt, + payloadEventType, + normalizeOpenAIWSLogValue(strings.Join(sortedKeys(payload), ",")), + resolvePayloadBytes(), + normalizeOpenAIWSLogValue(summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)), + normalizeOpenAIWSLogValue(summarizeOpenAIWSInput(payload["input"])), + streamValue, + normalizeOpenAIWSLogValue(payloadStrategy), + normalizeOpenAIWSLogValue(strings.Join(removedKeys, ",")), + previousResponseID != "", + promptCacheKey != "", + hasTools, + ) + } + + stateStore := s.getOpenAIWSStateStore() + groupID := getOpenAIGroupIDFromContext(c) + sessionHash := s.GenerateSessionHash(c, nil) + if sessionHash == "" { + var legacySessionHash string + sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey) + attachOpenAILegacySessionHashToGin(c, legacySessionHash) + } + if turnState == "" && stateStore != nil && sessionHash != "" { + if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { + turnState = savedTurnState + } + } + preferredConnID := "" + if stateStore != nil && previousResponseID != "" { + if connID, ok := stateStore.GetResponseConn(previousResponseID); ok { + preferredConnID = connID + } + } + storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account) + if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" { + if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { + preferredConnID = connID + } + } + storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() + forceNewConnByPolicy := shouldForceNewConnOnStoreDisabled(storeDisabledConnMode, lastFailureReason) + forceNewConn := forceNewConnByPolicy && storeDisabled && previousResponseID == "" && sessionHash != "" && preferredConnID == "" + wsHeaders, sessionResolution := s.buildOpenAIWSHeaders(c, account, token, decision, isCodexCLI, turnState, turnMetadata, promptCacheKey) + logOpenAIWSModeDebug( + "acquire_start account_id=%d account_type=%s transport=%s preferred_conn_id=%s has_previous_response_id=%v session_hash=%s has_turn_state=%v turn_state_len=%d has_turn_metadata=%v turn_metadata_len=%d store_disabled=%v store_disabled_conn_mode=%s retry_last_reason=%s force_new_conn=%v header_user_agent=%s header_openai_beta=%s header_originator=%s header_accept_language=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_prompt_cache_key=%v has_chatgpt_account_id=%v has_authorization=%v has_session_id=%v has_conversation_id=%v proxy_enabled=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + previousResponseID != "", + truncateOpenAIWSLogValue(sessionHash, 12), + turnState != "", + len(turnState), + turnMetadata != "", + len(turnMetadata), + storeDisabled, + normalizeOpenAIWSLogValue(storeDisabledConnMode), + truncateOpenAIWSLogValue(lastFailureReason, openAIWSLogValueMaxLen), + forceNewConn, + openAIWSHeaderValueForLog(wsHeaders, "user-agent"), + openAIWSHeaderValueForLog(wsHeaders, "openai-beta"), + openAIWSHeaderValueForLog(wsHeaders, "originator"), + openAIWSHeaderValueForLog(wsHeaders, "accept-language"), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + promptCacheKey != "", + hasOpenAIWSHeader(wsHeaders, "chatgpt-account-id"), + hasOpenAIWSHeader(wsHeaders, "authorization"), + hasOpenAIWSHeader(wsHeaders, "session_id"), + hasOpenAIWSHeader(wsHeaders, "conversation_id"), + account.ProxyID != nil && account.Proxy != nil, + ) + + acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout()) + defer acquireCancel() + + lease, err := s.getOpenAIWSConnPool().Acquire(acquireCtx, openAIWSAcquireRequest{ + Account: account, + WSURL: wsURL, + Headers: wsHeaders, + PreferredConnID: preferredConnID, + ForceNewConn: forceNewConn, + ProxyURL: func() string { + if account.ProxyID != nil && account.Proxy != nil { + return account.Proxy.URL() + } + return "" + }(), + }) + if err != nil { + dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err) + logOpenAIWSModeInfo( + "acquire_fail account_id=%d account_type=%s transport=%s reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_new_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(err)), + dialStatus, + dialClass, + dialCloseStatus, + truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen), + dialRespServer, + dialRespVia, + dialRespCFRay, + dialRespReqID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + forceNewConn, + wsHost, + wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err) + } + defer lease.Release() + connID := strings.TrimSpace(lease.ConnID()) + logOpenAIWSModeDebug( + "connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + connID, + lease.Reused(), + lease.ConnPickDuration().Milliseconds(), + lease.QueueWaitDuration().Milliseconds(), + previousResponseID != "", + ) + if previousResponseID != "" { + logOpenAIWSModeInfo( + "continuation_probe account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s conn_reused=%v store_disabled=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v", + account.ID, + account.Type, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(previousResponseIDKind), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + lease.Reused(), + storeDisabled, + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + turnState != "", + len(turnState), + promptCacheKey != "", + ) + } + if c != nil { + SetOpsLatencyMs(c, OpsOpenAIWSConnPickMsKey, lease.ConnPickDuration().Milliseconds()) + SetOpsLatencyMs(c, OpsOpenAIWSQueueWaitMsKey, lease.QueueWaitDuration().Milliseconds()) + c.Set(OpsOpenAIWSConnReusedKey, lease.Reused()) + if connID != "" { + c.Set(OpsOpenAIWSConnIDKey, connID) + } + } + + handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)) + logOpenAIWSModeDebug( + "handshake account_id=%d conn_id=%s has_turn_state=%v turn_state_len=%d", + account.ID, + connID, + handshakeTurnState != "", + len(handshakeTurnState), + ) + if handshakeTurnState != "" { + if stateStore != nil && sessionHash != "" { + stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL()) + } + if c != nil { + c.Header(http.CanonicalHeaderKey(openAIWSTurnStateHeader), handshakeTurnState) + } + } + + if err := s.performOpenAIWSGeneratePrewarm( + ctx, + lease, + decision, + payload, + previousResponseID, + reqBody, + account, + stateStore, + groupID, + ); err != nil { + return nil, err + } + + if err := lease.WriteJSONWithContextTimeout(ctx, payload, s.openAIWSWriteTimeout()); err != nil { + lease.MarkBroken() + logOpenAIWSModeInfo( + "write_request_fail account_id=%d conn_id=%s cause=%s payload_bytes=%d", + account.ID, + connID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + resolvePayloadBytes(), + ) + return nil, wrapOpenAIWSFallback("write_request", err) + } + if debugEnabled { + logOpenAIWSModeDebug( + "write_request_sent account_id=%d conn_id=%s stream=%v payload_bytes=%d previous_response_id=%s", + account.ID, + connID, + reqStream, + resolvePayloadBytes(), + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + ) + } + + usage := &OpenAIUsage{} + var firstTokenMs *int + responseID := "" + var finalResponse []byte + wroteDownstream := false + needModelReplace := originalModel != mappedModel + var mappedModelBytes []byte + if needModelReplace && mappedModel != "" { + mappedModelBytes = []byte(mappedModel) + } + bufferedStreamEvents := make([][]byte, 0, 4) + eventCount := 0 + tokenEventCount := 0 + terminalEventCount := 0 + bufferedEventCount := 0 + flushedBufferedEventCount := 0 + firstEventType := "" + lastEventType := "" + + var flusher http.Flusher + if reqStream { + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), http.Header{}, s.responseHeaderFilter) + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + f, ok := c.Writer.(http.Flusher) + if !ok { + lease.MarkBroken() + return nil, wrapOpenAIWSFallback("streaming_not_supported", errors.New("streaming not supported")) + } + flusher = f + } + + clientDisconnected := false + flushBatchSize := s.openAIWSEventFlushBatchSize() + flushInterval := s.openAIWSEventFlushInterval() + pendingFlushEvents := 0 + lastFlushAt := time.Now() + flushStreamWriter := func(force bool) { + if clientDisconnected || flusher == nil || pendingFlushEvents <= 0 { + return + } + if !force && flushBatchSize > 1 && pendingFlushEvents < flushBatchSize { + if flushInterval <= 0 || time.Since(lastFlushAt) < flushInterval { + return + } + } + flusher.Flush() + pendingFlushEvents = 0 + lastFlushAt = time.Now() + } + emitStreamMessage := func(message []byte, forceFlush bool) { + if clientDisconnected { + return + } + frame := make([]byte, 0, len(message)+8) + frame = append(frame, "data: "...) + frame = append(frame, message...) + frame = append(frame, '\n', '\n') + _, wErr := c.Writer.Write(frame) + if wErr == nil { + wroteDownstream = true + pendingFlushEvents++ + flushStreamWriter(forceFlush) + return + } + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d", account.ID) + } + flushBufferedStreamEvents := func(reason string) { + if len(bufferedStreamEvents) == 0 { + return + } + flushed := len(bufferedStreamEvents) + for _, buffered := range bufferedStreamEvents { + emitStreamMessage(buffered, false) + } + bufferedStreamEvents = bufferedStreamEvents[:0] + flushStreamWriter(true) + flushedBufferedEventCount += flushed + if debugEnabled { + logOpenAIWSModeDebug( + "buffer_flush account_id=%d conn_id=%s reason=%s flushed=%d total_flushed=%d client_disconnected=%v", + account.ID, + connID, + truncateOpenAIWSLogValue(reason, openAIWSLogValueMaxLen), + flushed, + flushedBufferedEventCount, + clientDisconnected, + ) + } + } + + readTimeout := s.openAIWSReadTimeout() + + for { + message, readErr := lease.ReadMessageWithContextTimeout(ctx, readTimeout) + if readErr != nil { + lease.MarkBroken() + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "read_fail account_id=%d conn_id=%s wrote_downstream=%v close_status=%s close_reason=%s cause=%s events=%d token_events=%d terminal_events=%d buffered_pending=%d buffered_flushed=%d first_event=%s last_event=%s", + account.ID, + connID, + wroteDownstream, + closeStatus, + closeReason, + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + eventCount, + tokenEventCount, + terminalEventCount, + len(bufferedStreamEvents), + flushedBufferedEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + ) + if !wroteDownstream { + return nil, wrapOpenAIWSFallback(classifyOpenAIWSReadFallbackReason(readErr), readErr) + } + if clientDisconnected { + break + } + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(readErr.Error()), "") + return nil, fmt.Errorf("openai ws read event: %w", readErr) + } + + eventType, eventResponseID, responseField := parseOpenAIWSEventEnvelope(message) + if eventType == "" { + continue + } + eventCount++ + if firstEventType == "" { + firstEventType = eventType + } + lastEventType = eventType + + if responseID == "" && eventResponseID != "" { + responseID = eventResponseID + } + + isTokenEvent := isOpenAIWSTokenEvent(eventType) + if isTokenEvent { + tokenEventCount++ + } + isTerminalEvent := isOpenAIWSTerminalEvent(eventType) + if isTerminalEvent { + terminalEventCount++ + } + if firstTokenMs == nil && isTokenEvent { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if debugEnabled && shouldLogOpenAIWSEvent(eventCount, eventType) { + logOpenAIWSModeDebug( + "event_received account_id=%d conn_id=%s idx=%d type=%s bytes=%d token=%v terminal=%v buffered_pending=%d", + account.ID, + connID, + eventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(message), + isTokenEvent, + isTerminalEvent, + len(bufferedStreamEvents), + ) + } + + if !clientDisconnected { + if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(message, mappedModelBytes) { + message = replaceOpenAIWSMessageModel(message, mappedModel, originalModel) + } + if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(message) { + if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(message); changed { + message = corrected + } + } + } + if openAIWSEventShouldParseUsage(eventType) { + parseOpenAIWSResponseUsageFromCompletedEvent(message, usage) + } + + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "Upstream websocket error" + } + fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + logOpenAIWSModeInfo( + "error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", + account.ID, + connID, + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + canFallback, + errCode, + errType, + errMessage, + ) + if fallbackReason == "previous_response_not_found" { + logOpenAIWSModeInfo( + "previous_response_not_found_diag account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s event_idx=%d req_stream=%v store_disabled=%v conn_reused=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v err_code=%s err_type=%s err_message=%s", + account.ID, + account.Type, + connID, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(previousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + eventCount, + reqStream, + storeDisabled, + lease.Reused(), + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + turnState != "", + len(turnState), + promptCacheKey != "", + errCode, + errType, + errMessage, + ) + } + // error 事件后连接不再可复用,避免回池后污染下一请求。 + lease.MarkBroken() + if !wroteDownstream && canFallback { + return nil, wrapOpenAIWSFallback(fallbackReason, errors.New(errMsg)) + } + statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw) + setOpsUpstreamError(c, statusCode, errMsg, "") + if reqStream && !clientDisconnected { + flushBufferedStreamEvents("error_event") + emitStreamMessage(message, true) + } + if !reqStream { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": errMsg, + }, + }) + } + return nil, fmt.Errorf("openai ws error event: %s", errMsg) + } + + if reqStream { + // 在首个 token 前先缓冲事件(如 response.created), + // 以便上游早期断连时仍可安全回退到 HTTP,不给下游发送半截流。 + shouldBuffer := firstTokenMs == nil && !isTokenEvent && !isTerminalEvent + if shouldBuffer { + buffered := make([]byte, len(message)) + copy(buffered, message) + bufferedStreamEvents = append(bufferedStreamEvents, buffered) + bufferedEventCount++ + if debugEnabled && shouldLogOpenAIWSBufferedEvent(bufferedEventCount) { + logOpenAIWSModeDebug( + "buffer_enqueue account_id=%d conn_id=%s idx=%d event_idx=%d event_type=%s buffer_size=%d", + account.ID, + connID, + bufferedEventCount, + eventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(bufferedStreamEvents), + ) + } + } else { + flushBufferedStreamEvents(eventType) + emitStreamMessage(message, isTerminalEvent) + } + } else { + if responseField.Exists() && responseField.Type == gjson.JSON { + finalResponse = []byte(responseField.Raw) + } + } + + if isTerminalEvent { + break + } + } + + if !reqStream { + if len(finalResponse) == 0 { + logOpenAIWSModeInfo( + "missing_final_response account_id=%d conn_id=%s events=%d token_events=%d terminal_events=%d wrote_downstream=%v", + account.ID, + connID, + eventCount, + tokenEventCount, + terminalEventCount, + wroteDownstream, + ) + if !wroteDownstream { + return nil, wrapOpenAIWSFallback("missing_final_response", errors.New("no terminal response payload")) + } + return nil, errors.New("ws finished without final response") + } + + if needModelReplace { + finalResponse = s.replaceModelInResponseBody(finalResponse, mappedModel, originalModel) + } + finalResponse = s.correctToolCallsInResponseBody(finalResponse) + populateOpenAIUsageFromResponseJSON(finalResponse, usage) + if responseID == "" { + responseID = strings.TrimSpace(gjson.GetBytes(finalResponse, "id").String()) + } + + c.Data(http.StatusOK, "application/json", finalResponse) + } else { + flushStreamWriter(true) + } + + if responseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) + stateStore.BindResponseConn(responseID, lease.ConnID(), ttl) + } + if stateStore != nil && storeDisabled && sessionHash != "" { + stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL()) + } + firstTokenMsValue := -1 + if firstTokenMs != nil { + firstTokenMsValue = *firstTokenMs + } + logOpenAIWSModeDebug( + "completed account_id=%d conn_id=%s response_id=%s stream=%v duration_ms=%d events=%d token_events=%d terminal_events=%d buffered_events=%d buffered_flushed=%d first_event=%s last_event=%s first_token_ms=%d wrote_downstream=%v client_disconnected=%v", + account.ID, + connID, + truncateOpenAIWSLogValue(strings.TrimSpace(responseID), openAIWSIDValueMaxLen), + reqStream, + time.Since(startTime).Milliseconds(), + eventCount, + tokenEventCount, + terminalEventCount, + bufferedEventCount, + flushedBufferedEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + firstTokenMsValue, + wroteDownstream, + clientDisconnected, + ) + + return &OpenAIForwardResult{ + RequestID: responseID, + Usage: *usage, + Model: originalModel, + ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +// ProxyResponsesWebSocketFromClient 处理客户端入站 WebSocket(OpenAI Responses WS Mode)并转发到上游。 +// 当前实现按“单请求 -> 终止事件 -> 下一请求”的顺序代理,适配 Codex CLI 的 turn 模式。 +func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *Account, + token string, + firstClientMessage []byte, + hooks *OpenAIWSIngressHooks, +) error { + if s == nil { + return errors.New("service is nil") + } + if c == nil { + return errors.New("gin context is nil") + } + if clientConn == nil { + return errors.New("client websocket is nil") + } + if account == nil { + return errors.New("account is nil") + } + if strings.TrimSpace(token) == "" { + return errors.New("token is empty") + } + + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) + modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled + ingressMode := OpenAIWSIngressModeShared + if modeRouterV2Enabled { + ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault) + if ingressMode == OpenAIWSIngressModeOff { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode is disabled for this account", + nil, + ) + } + } + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) + } + dedicatedMode := modeRouterV2Enabled && ingressMode == OpenAIWSIngressModeDedicated + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return fmt.Errorf("build ws url: %w", err) + } + wsHost := "-" + wsPath := "-" + if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil { + wsHost = normalizeOpenAIWSLogValue(parsedURL.Host) + wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) + } + debugEnabled := isOpenAIWSModeDebugEnabled() + + type openAIWSClientPayload struct { + payloadRaw []byte + rawForHash []byte + promptCacheKey string + previousResponseID string + originalModel string + payloadBytes int + } + + applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) { + next, err := sjson.SetBytes(current, path, value) + if err == nil { + return next, nil + } + + // 仅在确实需要修改 payload 且 sjson 失败时,退回 map 路径确保兼容性。 + payload := make(map[string]any) + if unmarshalErr := json.Unmarshal(current, &payload); unmarshalErr != nil { + return nil, err + } + switch path { + case "type", "model": + payload[path] = value + case "client_metadata." + openAIWSTurnMetadataHeader: + setOpenAIWSTurnMetadata(payload, fmt.Sprintf("%v", value)) + default: + return nil, err + } + rebuilt, marshalErr := json.Marshal(payload) + if marshalErr != nil { + return nil, marshalErr + } + return rebuilt, nil + } + + parseClientPayload := func(raw []byte) (openAIWSClientPayload, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "empty websocket request payload", nil) + } + if !gjson.ValidBytes(trimmed) { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) + } + + values := gjson.GetManyBytes(trimmed, "type", "model", "prompt_cache_key", "previous_response_id") + eventType := strings.TrimSpace(values[0].String()) + normalized := trimmed + switch eventType { + case "": + eventType = "response.create" + next, setErr := applyPayloadMutation(normalized, "type", eventType) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + case "response.create": + case "response.append": + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "response.append is not supported in ws v2; use response.create with previous_response_id", + nil, + ) + default: + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + fmt.Sprintf("unsupported websocket request type: %s", eventType), + nil, + ) + } + + originalModel := strings.TrimSpace(values[1].String()) + if originalModel == "" { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "model is required in response.create payload", + nil, + ) + } + promptCacheKey := strings.TrimSpace(values[2].String()) + previousResponseID := strings.TrimSpace(values[3].String()) + previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + if previousResponseID != "" && previousResponseIDKind == OpenAIPreviousResponseIDKindMessageID { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "previous_response_id must be a response.id (resp_*), not a message id", + nil, + ) + } + if turnMetadata := strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)); turnMetadata != "" { + next, setErr := applyPayloadMutation(normalized, "client_metadata."+openAIWSTurnMetadataHeader, turnMetadata) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + } + mappedModel := account.GetMappedModel(originalModel) + if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { + mappedModel = normalizedModel + } + if mappedModel != originalModel { + next, setErr := applyPayloadMutation(normalized, "model", mappedModel) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + } + + return openAIWSClientPayload{ + payloadRaw: normalized, + rawForHash: trimmed, + promptCacheKey: promptCacheKey, + previousResponseID: previousResponseID, + originalModel: originalModel, + payloadBytes: len(normalized), + }, nil + } + + firstPayload, err := parseClientPayload(firstClientMessage) + if err != nil { + return err + } + + turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) + stateStore := s.getOpenAIWSStateStore() + groupID := getOpenAIGroupIDFromContext(c) + sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash) + if turnState == "" && stateStore != nil && sessionHash != "" { + if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { + turnState = savedTurnState + } + } + + preferredConnID := "" + if stateStore != nil && firstPayload.previousResponseID != "" { + if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok { + preferredConnID = connID + } + } + + storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account) + storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() + if stateStore != nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" { + if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { + preferredConnID = connID + } + } + + isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) + baseAcquireReq := openAIWSAcquireRequest{ + Account: account, + WSURL: wsURL, + Headers: wsHeaders, + ProxyURL: func() string { + if account.ProxyID != nil && account.Proxy != nil { + return account.Proxy.URL() + } + return "" + }(), + ForceNewConn: false, + } + pool := s.getOpenAIWSConnPool() + if pool == nil { + return errors.New("openai ws conn pool is nil") + } + + logOpenAIWSModeInfo( + "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s store_disabled=%v has_session_hash=%v has_previous_response_id=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + wsHost, + wsPath, + normalizeOpenAIWSLogValue(ingressMode), + storeDisabled, + sessionHash != "", + firstPayload.previousResponseID != "", + ) + + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + wsHost, + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + sessionHash != "", + firstPayload.previousResponseID != "", + storeDisabled, + ) + } + if firstPayload.previousResponseID != "" { + firstPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(firstPayload.previousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_continuation_probe account_id=%d turn=%d previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s session_hash=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", + account.ID, + 1, + truncateOpenAIWSLogValue(firstPayload.previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(firstPreviousResponseIDKind), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), + turnState != "", + len(turnState), + firstPayload.promptCacheKey != "", + storeDisabled, + ) + } + + acquireTimeout := s.openAIWSAcquireTimeout() + if acquireTimeout <= 0 { + acquireTimeout = 30 * time.Second + } + + acquireTurnLease := func(turn int, preferred string, forcePreferredConn bool) (*openAIWSConnLease, error) { + req := cloneOpenAIWSAcquireRequest(baseAcquireReq) + req.PreferredConnID = strings.TrimSpace(preferred) + req.ForcePreferredConn = forcePreferredConn + // dedicated 模式下每次获取均新建连接,避免跨会话复用残留上下文。 + req.ForceNewConn = dedicatedMode + acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout) + lease, acquireErr := pool.Acquire(acquireCtx, req) + acquireCancel() + if acquireErr != nil { + dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr) + logOpenAIWSModeInfo( + "ingress_ws_upstream_acquire_fail account_id=%d turn=%d reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_preferred_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + turn, + normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(acquireErr)), + dialStatus, + dialClass, + dialCloseStatus, + truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen), + dialRespServer, + dialRespVia, + dialRespCFRay, + dialRespReqID, + truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), + forcePreferredConn, + wsHost, + wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream continuation connection is unavailable; please restart the conversation", + acquireErr, + ) + } + if errors.Is(acquireErr, context.DeadlineExceeded) || errors.Is(acquireErr, errOpenAIWSConnQueueFull) { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket is busy, please retry later", + acquireErr, + ) + } + return nil, acquireErr + } + connID := strings.TrimSpace(lease.ConnID()) + if handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)); handshakeTurnState != "" { + turnState = handshakeTurnState + if stateStore != nil && sessionHash != "" { + stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL()) + } + updatedHeaders := cloneHeader(baseAcquireReq.Headers) + if updatedHeaders == nil { + updatedHeaders = make(http.Header) + } + updatedHeaders.Set(openAIWSTurnStateHeader, handshakeTurnState) + baseAcquireReq.Headers = updatedHeaders + } + logOpenAIWSModeInfo( + "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + lease.Reused(), + lease.ConnPickDuration().Milliseconds(), + lease.QueueWaitDuration().Milliseconds(), + truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), + ) + return lease, nil + } + + writeClientMessage := func(message []byte) error { + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + defer cancel() + return clientConn.Write(writeCtx, coderws.MessageText, message) + } + + readClientMessage := func() ([]byte, error) { + msgType, payload, readErr := clientConn.Read(ctx) + if readErr != nil { + return nil, readErr + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + fmt.Sprintf("unsupported websocket client message type: %s", msgType.String()), + nil, + ) + } + return payload, nil + } + + sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) { + if lease == nil { + return nil, errors.New("upstream websocket lease is nil") + } + turnStart := time.Now() + wroteDownstream := false + if err := lease.WriteJSONWithContextTimeout(ctx, json.RawMessage(payload), s.openAIWSWriteTimeout()); err != nil { + return nil, wrapOpenAIWSIngressTurnError( + "write_upstream", + fmt.Errorf("write upstream websocket request: %w", err), + false, + ) + } + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_turn_request_sent account_id=%d turn=%d conn_id=%s payload_bytes=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + payloadBytes, + ) + } + + responseID := "" + usage := OpenAIUsage{} + var firstTokenMs *int + reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true) + turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID) + turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key") + turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account) + turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + eventCount := 0 + tokenEventCount := 0 + terminalEventCount := 0 + firstEventType := "" + lastEventType := "" + needModelReplace := false + clientDisconnected := false + mappedModel := "" + var mappedModelBytes []byte + if originalModel != "" { + mappedModel = account.GetMappedModel(originalModel) + if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { + mappedModel = normalizedModel + } + needModelReplace = mappedModel != "" && mappedModel != originalModel + if needModelReplace { + mappedModelBytes = []byte(mappedModel) + } + } + for { + upstreamMessage, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) + if readErr != nil { + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnError( + "read_upstream", + fmt.Errorf("read upstream websocket event: %w", readErr), + wroteDownstream, + ) + } + + eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage) + if responseID == "" && eventResponseID != "" { + responseID = eventResponseID + } + if eventType != "" { + eventCount++ + if firstEventType == "" { + firstEventType = eventType + } + lastEventType = eventType + } + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage) + fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound && + turnPreviousResponseID != "" && + !turnHasFunctionCallOutput && + s.openAIWSIngressPreviousResponseRecoveryEnabled() && + !wroteDownstream + if recoverablePrevNotFound { + // 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。 + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + errCode, + errType, + errMessage, + truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + turnStoreDisabled, + turnPromptCacheKey != "", + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + errCode, + errType, + errMessage, + truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + turnStoreDisabled, + turnPromptCacheKey != "", + ) + } + // previous_response_not_found 在 ingress 模式支持单次恢复重试: + // 不把该 error 直接下发客户端,而是由上层去掉 previous_response_id 后重放当前 turn。 + if recoverablePrevNotFound { + lease.MarkBroken() + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "previous response not found" + } + return nil, wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New(errMsg), + false, + ) + } + } + isTokenEvent := isOpenAIWSTokenEvent(eventType) + if isTokenEvent { + tokenEventCount++ + } + isTerminalEvent := isOpenAIWSTerminalEvent(eventType) + if isTerminalEvent { + terminalEventCount++ + } + if firstTokenMs == nil && isTokenEvent { + ms := int(time.Since(turnStart).Milliseconds()) + firstTokenMs = &ms + } + if openAIWSEventShouldParseUsage(eventType) { + parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage) + } + + if !clientDisconnected { + if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) { + upstreamMessage = replaceOpenAIWSMessageModel(upstreamMessage, mappedModel, originalModel) + } + if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) { + if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(upstreamMessage); changed { + upstreamMessage = corrected + } + } + if err := writeClientMessage(upstreamMessage); err != nil { + if isOpenAIWSClientDisconnectError(err) { + clientDisconnected = true + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err) + logOpenAIWSModeInfo( + "ingress_ws_client_disconnected_drain account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + } else { + return nil, wrapOpenAIWSIngressTurnError( + "write_client", + fmt.Errorf("write client websocket event: %w", err), + wroteDownstream, + ) + } + } else { + wroteDownstream = true + } + } + if isTerminalEvent { + // 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。 + if clientDisconnected { + lease.MarkBroken() + } + firstTokenMsValue := -1 + if firstTokenMs != nil { + firstTokenMsValue = *firstTokenMs + } + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + time.Since(turnStart).Milliseconds(), + eventCount, + tokenEventCount, + terminalEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + firstTokenMsValue, + clientDisconnected, + ) + } + return &OpenAIForwardResult{ + RequestID: responseID, + Usage: usage, + Model: originalModel, + ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + Duration: time.Since(turnStart), + FirstTokenMs: firstTokenMs, + }, nil + } + } + } + + currentPayload := firstPayload.payloadRaw + currentOriginalModel := firstPayload.originalModel + currentPayloadBytes := firstPayload.payloadBytes + isStrictAffinityTurn := func(payload []byte) bool { + if !storeDisabled { + return false + } + return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != "" + } + var sessionLease *openAIWSConnLease + sessionConnID := "" + pinnedSessionConnID := "" + unpinSessionConn := func(connID string) { + connID = strings.TrimSpace(connID) + if connID == "" || pinnedSessionConnID != connID { + return + } + pool.UnpinConn(account.ID, connID) + pinnedSessionConnID = "" + } + pinSessionConn := func(connID string) { + if !storeDisabled { + return + } + connID = strings.TrimSpace(connID) + if connID == "" || pinnedSessionConnID == connID { + return + } + if pinnedSessionConnID != "" { + pool.UnpinConn(account.ID, pinnedSessionConnID) + pinnedSessionConnID = "" + } + if pool.PinConn(account.ID, connID) { + pinnedSessionConnID = connID + } + } + releaseSessionLease := func() { + if sessionLease == nil { + return + } + if dedicatedMode { + // dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。 + sessionLease.MarkBroken() + } + unpinSessionConn(sessionConnID) + sessionLease.Release() + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_upstream_released account_id=%d conn_id=%s", + account.ID, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + ) + } + } + defer releaseSessionLease() + + turn := 1 + turnRetry := 0 + turnPrevRecoveryTried := false + lastTurnFinishedAt := time.Time{} + lastTurnResponseID := "" + lastTurnPayload := []byte(nil) + var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState + lastTurnReplayInput := []json.RawMessage(nil) + lastTurnReplayInputExists := false + currentTurnReplayInput := []json.RawMessage(nil) + currentTurnReplayInputExists := false + skipBeforeTurn := false + resetSessionLease := func(markBroken bool) { + if sessionLease == nil { + return + } + if markBroken { + sessionLease.MarkBroken() + } + releaseSessionLease() + sessionLease = nil + sessionConnID = "" + preferredConnID = "" + } + recoverIngressPrevResponseNotFound := func(relayErr error, turn int, connID string) bool { + if !isOpenAIWSIngressPreviousResponseNotFound(relayErr) { + return false + } + if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() { + return false + } + if isStrictAffinityTurn(currentPayload) { + // Layer 2:严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。 + // 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。 + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_layer2 account_id=%d turn=%d conn_id=%s store_disabled_conn_mode=%s action=drop_previous_response_id_retry", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(storeDisabledConnMode), + ) + } + turnPrevRecoveryTried = true + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + reason := "not_removed" + if dropErr != nil { + reason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(reason), + ) + return false + } + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + ) + return false + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id retry=1", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + retryIngressTurn := func(relayErr error, turn int, connID string) bool { + if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 { + return false + } + if isStrictAffinityTurn(currentPayload) { + logOpenAIWSModeInfo( + "ingress_ws_turn_retry_skip account_id=%d turn=%d conn_id=%s reason=strict_affinity", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + return false + } + turnRetry++ + logOpenAIWSModeInfo( + "ingress_ws_turn_retry account_id=%d turn=%d retry=%d reason=%s conn_id=%s", + account.ID, + turn, + turnRetry, + truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + for { + if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil { + if err := hooks.BeforeTurn(turn); err != nil { + return err + } + } + skipBeforeTurn = false + currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") + expectedPrev := strings.TrimSpace(lastTurnResponseID) + hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() + // store=false + function_call_output 场景必须有续链锚点。 + // 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。 + if shouldInferIngressFunctionCallOutputPreviousResponseID( + storeDisabled, + turn, + hasFunctionCallOutput, + currentPreviousResponseID, + expectedPrev, + ) { + updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev) + if setPrevErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_function_call_output_prev_infer_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else { + currentPayload = updatedPayload + currentPayloadBytes = len(updatedPayload) + currentPreviousResponseID = expectedPrev + logOpenAIWSModeInfo( + "ingress_ws_function_call_output_prev_infer account_id=%d turn=%d conn_id=%s action=set_previous_response_id previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } + } + nextReplayInput, nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence( + lastTurnReplayInput, + lastTurnReplayInputExists, + currentPayload, + currentPreviousResponseID != "", + ) + if replayInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_replay_input_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(replayInputErr.Error(), openAIWSLogValueMaxLen), + ) + currentTurnReplayInput = nil + currentTurnReplayInputExists = false + } else { + currentTurnReplayInput = nextReplayInput + currentTurnReplayInputExists = nextReplayInputExists + } + if storeDisabled && turn > 1 && currentPreviousResponseID != "" { + shouldKeepPreviousResponseID := false + strictReason := "" + var strictErr error + if lastTurnStrictState != nil { + shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseIDWithStrictState( + lastTurnStrictState, + currentPayload, + lastTurnResponseID, + hasFunctionCallOutput, + ) + } else { + shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID( + lastTurnPayload, + currentPayload, + lastTurnResponseID, + hasFunctionCallOutput, + ) + } + if strictErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s cause=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(strictErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + } else if !shouldKeepPreviousResponseID { + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + dropReason := "not_removed" + if dropErr != nil { + dropReason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + normalizeOpenAIWSLogValue(dropReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + } else { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=set_full_input_error previous_response_id=%s expected_previous_response_id=%s cause=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + hasFunctionCallOutput, + ) + } else { + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_full_create reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + currentPreviousResponseID = "" + } + } + } + } + forcePreferredConn := isStrictAffinityTurn(currentPayload) + if sessionLease == nil { + acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) + if acquireErr != nil { + return fmt.Errorf("acquire upstream websocket: %w", acquireErr) + } + sessionLease = acquiredLease + sessionConnID = strings.TrimSpace(sessionLease.ConnID()) + if storeDisabled { + pinSessionConn(sessionConnID) + } else { + unpinSessionConn(sessionConnID) + } + } + shouldPreflightPing := turn > 1 && sessionLease != nil && turnRetry == 0 + if shouldPreflightPing && openAIWSIngressPreflightPingIdle > 0 && !lastTurnFinishedAt.IsZero() { + if time.Since(lastTurnFinishedAt) < openAIWSIngressPreflightPingIdle { + shouldPreflightPing = false + } + } + if shouldPreflightPing { + if pingErr := sessionLease.PingWithTimeout(openAIWSConnHealthCheckTO); pingErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_upstream_preflight_ping_fail account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen), + ) + if forcePreferredConn { + if !turnPrevRecoveryTried && currentPreviousResponseID != "" { + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + reason := "not_removed" + if dropErr != nil { + reason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(reason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + } else { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error previous_response_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + turnPrevRecoveryTried = true + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + continue + } + } + } + resetSessionLease(true) + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream continuation connection is unavailable; please restart the conversation", + pingErr, + ) + } + resetSessionLease(true) + + acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) + if acquireErr != nil { + return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr) + } + sessionLease = acquiredLease + sessionConnID = strings.TrimSpace(sessionLease.ConnID()) + if storeDisabled { + pinSessionConn(sessionConnID) + } + } + } + connID := sessionConnID + if currentPreviousResponseID != "" { + chainedFromLast := expectedPrev != "" && currentPreviousResponseID == expectedPrev + currentPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(currentPreviousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_turn_chain account_id=%d turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v preferred_conn_id=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(currentPreviousResponseIDKind), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + chainedFromLast, + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), + turnState != "", + len(turnState), + openAIWSPayloadStringFromRaw(currentPayload, "prompt_cache_key") != "", + storeDisabled, + ) + } + + result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) + if relayErr != nil { + if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { + continue + } + if retryIngressTurn(relayErr, turn, connID) { + continue + } + finalErr := relayErr + if unwrapped := errors.Unwrap(relayErr); unwrapped != nil { + finalErr = unwrapped + } + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, nil, finalErr) + } + sessionLease.MarkBroken() + return finalErr + } + turnRetry = 0 + turnPrevRecoveryTried = false + lastTurnFinishedAt = time.Now() + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, result, nil) + } + if result == nil { + return errors.New("websocket turn result is nil") + } + responseID := strings.TrimSpace(result.RequestID) + lastTurnResponseID = responseID + lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload) + lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput) + lastTurnReplayInputExists = currentTurnReplayInputExists + nextStrictState, strictStateErr := buildOpenAIWSIngressPreviousTurnStrictState(currentPayload) + if strictStateErr != nil { + lastTurnStrictState = nil + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_state_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(strictStateErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + lastTurnStrictState = nextStrictState + } + + if responseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) + stateStore.BindResponseConn(responseID, connID, ttl) + } + if stateStore != nil && storeDisabled && sessionHash != "" { + stateStore.BindSessionConn(groupID, sessionHash, connID, s.openAIWSSessionStickyTTL()) + } + if connID != "" { + preferredConnID = connID + } + + nextClientMessage, readErr := readClientMessage() + if readErr != nil { + if isOpenAIWSClientDisconnectError(readErr) { + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + return nil + } + return fmt.Errorf("read client websocket request: %w", readErr) + } + + nextPayload, parseErr := parseClientPayload(nextClientMessage) + if parseErr != nil { + return parseErr + } + if nextPayload.promptCacheKey != "" { + // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接; + // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。 + updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey) + baseAcquireReq.Headers = updatedHeaders + } + if nextPayload.previousResponseID != "" { + expectedPrev := strings.TrimSpace(lastTurnResponseID) + chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev + nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v", + account.ID, + turn, + turn+1, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(nextPreviousResponseIDKind), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + chainedFromLast, + nextPayload.promptCacheKey != "", + storeDisabled, + ) + } + if stateStore != nil && nextPayload.previousResponseID != "" { + if stickyConnID, ok := stateStore.GetResponseConn(nextPayload.previousResponseID); ok { + if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID { + logOpenAIWSModeInfo( + "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + ) + } else { + preferredConnID = stickyConnID + } + } + } + currentPayload = nextPayload.payloadRaw + currentOriginalModel = nextPayload.originalModel + currentPayloadBytes = nextPayload.payloadBytes + storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account) + if !storeDisabled { + unpinSessionConn(sessionConnID) + } + turn++ + } +} + +func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled +} + +// performOpenAIWSGeneratePrewarm 在 WSv2 下执行可选的 generate=false 预热。 +// 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。 +func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( + ctx context.Context, + lease *openAIWSConnLease, + decision OpenAIWSProtocolDecision, + payload map[string]any, + previousResponseID string, + reqBody map[string]any, + account *Account, + stateStore OpenAIWSStateStore, + groupID int64, +) error { + if s == nil { + return nil + } + if lease == nil || account == nil { + logOpenAIWSModeInfo("prewarm_skip reason=invalid_state has_lease=%v has_account=%v", lease != nil, account != nil) + return nil + } + connID := strings.TrimSpace(lease.ConnID()) + if !s.isOpenAIWSGeneratePrewarmEnabled() { + return nil + } + if decision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d conn_id=%s reason=transport_not_v2 transport=%s", + account.ID, + connID, + normalizeOpenAIWSLogValue(string(decision.Transport)), + ) + return nil + } + if strings.TrimSpace(previousResponseID) != "" { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d conn_id=%s reason=has_previous_response_id previous_response_id=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + ) + return nil + } + if lease.IsPrewarmed() { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=already_prewarmed", account.ID, connID) + return nil + } + if NeedsToolContinuation(reqBody) { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=tool_continuation", account.ID, connID) + return nil + } + prewarmStart := time.Now() + logOpenAIWSModeInfo("prewarm_start account_id=%d conn_id=%s", account.ID, connID) + + prewarmPayload := make(map[string]any, len(payload)+1) + for k, v := range payload { + prewarmPayload[k] = v + } + prewarmPayload["generate"] = false + prewarmPayloadJSON := payloadAsJSONBytes(prewarmPayload) + + if err := lease.WriteJSONWithContextTimeout(ctx, prewarmPayload, s.openAIWSWriteTimeout()); err != nil { + lease.MarkBroken() + logOpenAIWSModeInfo( + "prewarm_write_fail account_id=%d conn_id=%s cause=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return wrapOpenAIWSFallback("prewarm_write", err) + } + logOpenAIWSModeInfo("prewarm_write_sent account_id=%d conn_id=%s payload_bytes=%d", account.ID, connID, len(prewarmPayloadJSON)) + + prewarmResponseID := "" + prewarmEventCount := 0 + prewarmTerminalCount := 0 + for { + message, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) + if readErr != nil { + lease.MarkBroken() + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "prewarm_read_fail account_id=%d conn_id=%s close_status=%s close_reason=%s cause=%s events=%d", + account.ID, + connID, + closeStatus, + closeReason, + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + prewarmEventCount, + ) + return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr) + } + + eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message) + if eventType == "" { + continue + } + prewarmEventCount++ + if prewarmResponseID == "" && eventResponseID != "" { + prewarmResponseID = eventResponseID + } + if prewarmEventCount <= openAIWSPrewarmEventLogHead || eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + logOpenAIWSModeInfo( + "prewarm_event account_id=%d conn_id=%s idx=%d type=%s bytes=%d", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(message), + ) + } + + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "OpenAI websocket prewarm error" + } + fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + logOpenAIWSModeInfo( + "prewarm_error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + canFallback, + errCode, + errType, + errMessage, + ) + lease.MarkBroken() + if canFallback { + return wrapOpenAIWSFallback("prewarm_"+fallbackReason, errors.New(errMsg)) + } + return wrapOpenAIWSFallback("prewarm_error_event", errors.New(errMsg)) + } + + if isOpenAIWSTerminalEvent(eventType) { + prewarmTerminalCount++ + break + } + } + + lease.MarkPrewarmed() + if prewarmResponseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl)) + stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl) + } + logOpenAIWSModeInfo( + "prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d", + account.ID, + connID, + truncateOpenAIWSLogValue(prewarmResponseID, openAIWSIDValueMaxLen), + prewarmEventCount, + prewarmTerminalCount, + time.Since(prewarmStart).Milliseconds(), + ) + return nil +} + +func payloadAsJSON(payload map[string]any) string { + return string(payloadAsJSONBytes(payload)) +} + +func payloadAsJSONBytes(payload map[string]any) []byte { + if len(payload) == 0 { + return []byte("{}") + } + body, err := json.Marshal(payload) + if err != nil { + return []byte("{}") + } + return body +} + +func isOpenAIWSTerminalEvent(eventType string) bool { + switch strings.TrimSpace(eventType) { + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": + return true + default: + return false + } +} + +func isOpenAIWSTokenEvent(eventType string) bool { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return false + } + switch eventType { + case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": + return false + } + if strings.Contains(eventType, ".delta") { + return true + } + if strings.HasPrefix(eventType, "response.output_text") { + return true + } + if strings.HasPrefix(eventType, "response.output") { + return true + } + return eventType == "response.completed" || eventType == "response.done" +} + +func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte { + if len(message) == 0 { + return message + } + if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel { + return message + } + if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) { + return message + } + modelValues := gjson.GetManyBytes(message, "model", "response.model") + replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel + replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel + if !replaceModel && !replaceResponseModel { + return message + } + updated := message + if replaceModel { + if next, err := sjson.SetBytes(updated, "model", toModel); err == nil { + updated = next + } + } + if replaceResponseModel { + if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil { + updated = next + } + } + return updated +} + +func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) { + if usage == nil || len(body) == 0 { + return + } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + "usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} + +func getOpenAIGroupIDFromContext(c *gin.Context) int64 { + if c == nil { + return 0 + } + value, exists := c.Get("api_key") + if !exists { + return 0 + } + apiKey, ok := value.(*APIKey) + if !ok || apiKey == nil || apiKey.GroupID == nil { + return 0 + } + return *apiKey.GroupID +} + +// SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。 +// 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。 +func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( + ctx context.Context, + groupID *int64, + previousResponseID string, + requestedModel string, + excludedIDs map[int64]struct{}, +) (*AccountSelectionResult, error) { + if s == nil { + return nil, nil + } + responseID := strings.TrimSpace(previousResponseID) + if responseID == "" { + return nil, nil + } + store := s.getOpenAIWSStateStore() + if store == nil { + return nil, nil + } + + accountID, err := store.GetResponseAccount(ctx, derefGroupID(groupID), responseID) + if err != nil || accountID <= 0 { + return nil, nil + } + if excludedIDs != nil { + if _, excluded := excludedIDs[accountID]; excluded { + return nil, nil + } + } + + account, err := s.getSchedulableAccount(ctx, accountID) + if err != nil || account == nil { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + // 非 WSv2 场景(如 force_http/全局关闭)不应使用 previous_response_id 粘连, + // 以保持“回滚到 HTTP”后的历史行为一致性。 + if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return nil, nil + } + if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + if requestedModel != "" && !account.IsModelSupported(requestedModel) { + return nil, nil + } + + result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if acquireErr == nil && result.Acquired { + logOpenAIWSBindResponseAccountWarn( + derefGroupID(groupID), + accountID, + responseID, + store.BindResponseAccount(ctx, derefGroupID(groupID), responseID, accountID, s.openAIWSResponseStickyTTL()), + ) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + cfg := s.schedulingConfig() + if s.concurrencyService != nil { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + return nil, nil +} + +func classifyOpenAIWSAcquireError(err error) string { + if err == nil { + return "acquire_conn" + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) { + switch dialErr.StatusCode { + case 426: + return "upgrade_required" + case 401, 403: + return "auth_failed" + case 429: + return "upstream_rate_limited" + } + if dialErr.StatusCode >= 500 { + return "upstream_5xx" + } + return "dial_failed" + } + if errors.Is(err, errOpenAIWSConnQueueFull) { + return "conn_queue_full" + } + if errors.Is(err, errOpenAIWSPreferredConnUnavailable) { + return "preferred_conn_unavailable" + } + if errors.Is(err, context.DeadlineExceeded) { + return "acquire_timeout" + } + return "acquire_conn" +} + +func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + switch code { + case "upgrade_required": + return "upgrade_required", true + case "websocket_not_supported", "websocket_unsupported": + return "ws_unsupported", true + case "websocket_connection_limit_reached": + return "ws_connection_limit_reached", true + case "previous_response_not_found": + return "previous_response_not_found", true + } + if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { + return "upgrade_required", true + } + if strings.Contains(errType, "upgrade") { + return "upgrade_required", true + } + if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") { + return "ws_unsupported", true + } + if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") { + return "ws_connection_limit_reached", true + } + if strings.Contains(msg, "previous_response_not_found") || + (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) { + return "previous_response_not_found", true + } + if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") { + return "upstream_error_event", true + } + return "event_error", false +} + +func classifyOpenAIWSErrorEvent(message []byte) (string, bool) { + if len(message) == 0 { + return "event_error", false + } + return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + switch { + case strings.Contains(errType, "invalid_request"), + strings.Contains(code, "invalid_request"), + strings.Contains(code, "bad_request"), + code == "previous_response_not_found": + return http.StatusBadRequest + case strings.Contains(errType, "authentication"), + strings.Contains(code, "invalid_api_key"), + strings.Contains(code, "unauthorized"): + return http.StatusUnauthorized + case strings.Contains(errType, "permission"), + strings.Contains(code, "forbidden"): + return http.StatusForbidden + case strings.Contains(errType, "rate_limit"), + strings.Contains(code, "rate_limit"), + strings.Contains(code, "insufficient_quota"): + return http.StatusTooManyRequests + default: + return http.StatusBadGateway + } +} + +func openAIWSErrorHTTPStatus(message []byte) int { + if len(message) == 0 { + return http.StatusBadGateway + } + codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message) + return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) +} + +func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration { + if s == nil || s.cfg == nil { + return 30 * time.Second + } + seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds + if seconds <= 0 { + return 0 + } + return time.Duration(seconds) * time.Second +} + +func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool { + if s == nil || accountID <= 0 { + return false + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return false + } + rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID) + if !ok || rawUntil == nil { + return false + } + until, ok := rawUntil.(time.Time) + if !ok || until.IsZero() { + s.openaiWSFallbackUntil.Delete(accountID) + return false + } + if time.Now().Before(until) { + return true + } + s.openaiWSFallbackUntil.Delete(accountID) + return false +} + +func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) { + if s == nil || accountID <= 0 { + return + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return + } + s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown)) +} + +func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) { + if s == nil || accountID <= 0 { + return + } + s.openaiWSFallbackUntil.Delete(accountID) +} diff --git a/backend/internal/service/openai_ws_forwarder_benchmark_test.go b/backend/internal/service/openai_ws_forwarder_benchmark_test.go new file mode 100644 index 00000000..bd03ab5a --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_benchmark_test.go @@ -0,0 +1,127 @@ +package service + +import ( + "fmt" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +var ( + benchmarkOpenAIWSPayloadJSONSink string + benchmarkOpenAIWSStringSink string + benchmarkOpenAIWSBoolSink bool + benchmarkOpenAIWSBytesSink []byte +) + +func BenchmarkOpenAIWSForwarderHotPath(b *testing.B) { + cfg := &config.Config{} + svc := &OpenAIGatewayService{cfg: cfg} + account := &Account{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + reqBody := benchmarkOpenAIWSHotPathRequest() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + payload := svc.buildOpenAIWSCreatePayload(reqBody, account) + _, _ = applyOpenAIWSRetryPayloadStrategy(payload, 2) + setOpenAIWSTurnMetadata(payload, `{"trace":"bench","turn":"1"}`) + + benchmarkOpenAIWSStringSink = openAIWSPayloadString(payload, "previous_response_id") + benchmarkOpenAIWSBoolSink = payload["tools"] != nil + benchmarkOpenAIWSStringSink = summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN) + benchmarkOpenAIWSStringSink = summarizeOpenAIWSInput(payload["input"]) + benchmarkOpenAIWSPayloadJSONSink = payloadAsJSON(payload) + } +} + +func benchmarkOpenAIWSHotPathRequest() map[string]any { + tools := make([]map[string]any, 0, 24) + for i := 0; i < 24; i++ { + tools = append(tools, map[string]any{ + "type": "function", + "name": fmt.Sprintf("tool_%02d", i), + "description": "benchmark tool schema", + "parameters": map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + "limit": map[string]any{"type": "number"}, + }, + "required": []string{"query"}, + }, + }) + } + + input := make([]map[string]any, 0, 16) + for i := 0; i < 16; i++ { + input = append(input, map[string]any{ + "role": "user", + "type": "message", + "content": fmt.Sprintf("benchmark message %d", i), + }) + } + + return map[string]any{ + "type": "response.create", + "model": "gpt-5.3-codex", + "input": input, + "tools": tools, + "parallel_tool_calls": true, + "previous_response_id": "resp_benchmark_prev", + "prompt_cache_key": "bench-cache-key", + "reasoning": map[string]any{"effort": "medium"}, + "instructions": "benchmark instructions", + "store": false, + } +} + +func BenchmarkOpenAIWSEventEnvelopeParse(b *testing.B) { + event := []byte(`{"type":"response.completed","response":{"id":"resp_bench_1","model":"gpt-5.1","usage":{"input_tokens":12,"output_tokens":8}}}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + eventType, responseID, response := parseOpenAIWSEventEnvelope(event) + benchmarkOpenAIWSStringSink = eventType + benchmarkOpenAIWSStringSink = responseID + benchmarkOpenAIWSBoolSink = response.Exists() + } +} + +func BenchmarkOpenAIWSErrorEventFieldReuse(b *testing.B) { + event := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(event) + benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw) + code, errType, errMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw) + benchmarkOpenAIWSStringSink = code + benchmarkOpenAIWSStringSink = errType + benchmarkOpenAIWSStringSink = errMsg + benchmarkOpenAIWSBoolSink = openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) > 0 + } +} + +func BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath(b *testing.B) { + event := []byte(`{"type":"response.output_text.delta","delta":"hello world"}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model") + } +} + +func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) { + event := []byte(`{"type":"response.completed","model":"gpt-5.1","response":{"id":"resp_1","model":"gpt-5.1"}}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model") + } +} diff --git a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go new file mode 100644 index 00000000..76167603 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go @@ -0,0 +1,73 @@ +package service + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseOpenAIWSEventEnvelope(t *testing.T) { + eventType, responseID, response := parseOpenAIWSEventEnvelope([]byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`)) + require.Equal(t, "response.completed", eventType) + require.Equal(t, "resp_1", responseID) + require.True(t, response.Exists()) + require.Equal(t, `{"id":"resp_1","model":"gpt-5.1"}`, response.Raw) + + eventType, responseID, response = parseOpenAIWSEventEnvelope([]byte(`{"type":"response.delta","id":"evt_1"}`)) + require.Equal(t, "response.delta", eventType) + require.Equal(t, "evt_1", responseID) + require.False(t, response.Exists()) +} + +func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) { + usage := &OpenAIUsage{} + parseOpenAIWSResponseUsageFromCompletedEvent( + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`), + usage, + ) + require.Equal(t, 11, usage.InputTokens) + require.Equal(t, 7, usage.OutputTokens) + require.Equal(t, 3, usage.CacheReadInputTokens) +} + +func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) { + message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`) + codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + + wrappedReason, wrappedRecoverable := classifyOpenAIWSErrorEvent(message) + rawReason, rawRecoverable := classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw) + require.Equal(t, wrappedReason, rawReason) + require.Equal(t, wrappedRecoverable, rawRecoverable) + + wrappedStatus := openAIWSErrorHTTPStatus(message) + rawStatus := openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) + require.Equal(t, wrappedStatus, rawStatus) + require.Equal(t, http.StatusBadRequest, rawStatus) + + wrappedCode, wrappedType, wrappedMsg := summarizeOpenAIWSErrorEventFields(message) + rawCode, rawType, rawMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw) + require.Equal(t, wrappedCode, rawCode) + require.Equal(t, wrappedType, rawType) + require.Equal(t, wrappedMsg, rawMsg) +} + +func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) { + require.False(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_text.delta","delta":"hello"}`))) + require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"tool_calls":[{"id":"tc1"}]}}`))) + require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"function_call"}}`))) +} + +func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) { + noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`) + require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, "gpt-5.1", "custom-model"))) + + rootOnly := []byte(`{"type":"response.created","model":"gpt-5.1"}`) + require.Equal(t, `{"type":"response.created","model":"custom-model"}`, string(replaceOpenAIWSMessageModel(rootOnly, "gpt-5.1", "custom-model"))) + + responseOnly := []byte(`{"type":"response.completed","response":{"model":"gpt-5.1"}}`) + require.Equal(t, `{"type":"response.completed","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(responseOnly, "gpt-5.1", "custom-model"))) + + both := []byte(`{"model":"gpt-5.1","response":{"model":"gpt-5.1"}}`) + require.Equal(t, `{"model":"custom-model","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(both, "gpt-5.1", "custom-model"))) +} diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go new file mode 100644 index 00000000..5a3c12c3 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -0,0 +1,2483 @@ +package service + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossTurns(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 114, + Name: "openai-ingress-session-lease", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + turnWSModeCh := make(chan bool, 2) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + if turnErr == nil && result != nil { + turnWSModeCh <- result.OpenAIWSMode + } + }, + } + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) + firstTurnEvent := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(firstTurnEvent, "type").String()) + require.Equal(t, "resp_ingress_turn_1", gjson.GetBytes(firstTurnEvent, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_ingress_turn_1"}`) + secondTurnEvent := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(secondTurnEvent, "type").String()) + require.Equal(t, "resp_ingress_turn_2", gjson.GetBytes(secondTurnEvent, "response.id").String()) + require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式") + require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式") + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + metrics := svc.SnapshotOpenAIWSPoolMetrics() + require.Equal(t, int64(1), metrics.AcquireTotal, "同一 ingress 会话多 turn 应只获取一次上游 lease") + require.Equal(t, 1, captureDialer.DialCount(), "同一 ingress 会话应保持同一上游连接") + require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + upstreamConn1 := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + upstreamConn2 := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{upstreamConn1, upstreamConn2}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 441, + Name: "openai-ingress-dedicated", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + + serverErrCh := make(chan error, 2) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + runSingleTurnSession := func(expectedResponseID string) { + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + msgType, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + require.Equal(t, expectedResponseID, gjson.GetBytes(event, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + } + + runSingleTurnSession("resp_dedicated_1") + runSingleTurnSession("resp_dedicated_2") + + require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: newOpenAIWSConnPool(cfg), + } + + account := &Account{ + ID: 442, + Name: "openai-ingress-off", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeOff, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + select { + case serverErr := <-serverErrCh: + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Equal(t, "websocket mode is disabled for this account", closeErr.Reason()) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropToFullCreate(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 140, + Name: "openai-ingress-prev-preflight-rewrite", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_preflight_rewrite_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_preflight_rewrite_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount(), "严格增量不成立时应在同一连接内降级为 full create") + require.Len(t, captureConn.writes, 2) + secondWrite := requestToJSONString(captureConn.writes[1]) + require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格增量不成立时应移除 previous_response_id,改为 full create") + require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "严格降级为 full create 时应重放完整 input 上下文") + require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) + require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropBeforePreflightPingFailReconnects(t *testing.T) { + gin.SetMode(gin.TestMode) + prevPreflightPingIdle := openAIWSIngressPreflightPingIdle + openAIWSIngressPreflightPingIdle = 0 + defer func() { + openAIWSIngressPreflightPingIdle = prevPreflightPingIdle + }() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSPreflightFailConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 142, + Name: "openai-ingress-prev-strict-drop-before-ping", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_ping_drop_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_ping_drop_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 严格降级后预检换连超时") + } + + require.Equal(t, 2, dialer.DialCount(), "严格降级为 full create 后,预检 ping 失败应允许换连") + require.Equal(t, 1, firstConn.WriteCount(), "首连接在预检失败后不应继续发送第二轮") + require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping") + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) + secondConn.mu.Unlock() + require.Len(t, secondWrites, 1) + secondWrite := requestToJSONString(secondWrites[0]) + require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格降级后重试应移除 previous_response_id") + require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array())) + require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) + require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreEnabledSkipsStrictPrevResponseEval(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 143, + Name: "openai-ingress-store-enabled-skip-strict", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true}`) + firstTurn := readMessage() + require.Equal(t, "resp_store_enabled_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true,"previous_response_id":"resp_stale_external"}`) + secondTurn := readMessage() + require.Equal(t, "resp_store_enabled_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 store=true 场景 websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "store=true 场景不应触发 store-disabled strict 规则") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponsePreflightSkipForFunctionCallOutput(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 141, + Name: "openai-ingress-prev-preflight-skip-fco", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false}`) + firstTurn := readMessage() + require.Equal(t, "resp_preflight_skip_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_preflight_skip_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 场景不应预改写 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 143, + Name: "openai-ingress-fco-auto-prev", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_auto_prev_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_auto_prev_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.Equal(t, "resp_auto_prev_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 缺失 previous_response_id 时应回填上一轮响应 ID") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenLastResponseIDMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 144, + Name: "openai-ingress-fco-auto-prev-skip", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(firstTurn, "type").String()) + require.Empty(t, gjson.GetBytes(firstTurn, "response.id").String(), "首轮响应不返回 response.id,模拟无法推导续链锚点") + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_skip_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_auto_prev_skip_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) { + gin.SetMode(gin.TestMode) + prevPreflightPingIdle := openAIWSIngressPreflightPingIdle + openAIWSIngressPreflightPingIdle = 0 + defer func() { + openAIWSIngressPreflightPingIdle = prevPreflightPingIdle + }() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSPreflightFailConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 116, + Name: "openai-ingress-preflight-ping", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_ping_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_ping_1"}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_ping_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 前 ping 失败应触发换连") + require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续向旧连接发送第二轮 turn") + require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应对旧连接执行 preflight ping") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreflightPingFailAutoRecoveryReconnects(t *testing.T) { + gin.SetMode(gin.TestMode) + prevPreflightPingIdle := openAIWSIngressPreflightPingIdle + openAIWSIngressPreflightPingIdle = 0 + defer func() { + openAIWSIngressPreflightPingIdle = prevPreflightPingIdle + }() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSPreflightFailConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 121, + Name: "openai-ingress-preflight-ping-strict-affinity", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_ping_strict_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_strict_1","input":[{"type":"input_text","text":"world"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_ping_strict_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 严格亲和自动恢复后结束超时") + } + + require.Equal(t, 2, dialer.DialCount(), "严格亲和 preflight ping 失败后应自动降级并换连重放") + require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续在旧连接写第二轮") + require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping") + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) + secondConn.mu.Unlock() + require.Len(t, secondWrites, 1) + secondWrite := requestToJSONString(secondWrites[0]) + require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "自动恢复重放应移除 previous_response_id") + require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "自动恢复重放应使用完整 input 上下文") + require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) + require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSWriteFailAfterFirstTurnConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 117, + Name: "openai-ingress-write-retry", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + var hooksMu sync.Mutex + beforeTurnCalls := make(map[int]int) + afterTurnCalls := make(map[int]int) + hooks := &OpenAIWSIngressHooks{ + BeforeTurn: func(turn int) error { + hooksMu.Lock() + beforeTurnCalls[turn]++ + hooksMu.Unlock() + return nil + }, + AfterTurn: func(turn int, _ *OpenAIForwardResult, _ error) { + hooksMu.Lock() + afterTurnCalls[turn]++ + hooksMu.Unlock() + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_write_retry_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_write_retry_1"}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_write_retry_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 上游写失败且未写下游时应自动重试并换连") + hooksMu.Lock() + beforeTurn1 := beforeTurnCalls[1] + beforeTurn2 := beforeTurnCalls[2] + afterTurn1 := afterTurnCalls[1] + afterTurn2 := afterTurnCalls[2] + hooksMu.Unlock() + require.Equal(t, 1, beforeTurn1, "首轮 turn BeforeTurn 应执行一次") + require.Equal(t, 1, beforeTurn2, "同一 turn 重试不应重复触发 BeforeTurn") + require.Equal(t, 1, afterTurn1, "首轮 turn AfterTurn 应执行一次") + require.Equal(t, 1, afterTurn2, "第二轮 turn AfterTurn 应执行一次") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoversByDroppingPrevID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":""}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 118, + Name: "openai-ingress-prev-recovery", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_seed_anchor"}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_prev_recover_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_recover_1"}`) + secondTurn := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(secondTurn, "type").String()) + require.Equal(t, "resp_turn_prev_recover_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应触发换连重试") + + firstConn.mu.Lock() + firstWrites := append([]map[string]any(nil), firstConn.writes...) + firstConn.mu.Unlock() + require.Len(t, firstWrites, 2, "首个连接应处理首轮与失败的第二轮请求") + require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists(), "失败轮次首发请求应包含 previous_response_id") + + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) + secondConn.mu.Unlock() + require.Len(t, secondWrites, 1, "恢复重试应在第二个连接发送一次请求") + require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreviousResponseNotFoundLayer2Recovery(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"missing strict anchor"}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 122, + Name: "openai-ingress-prev-strict-layer2", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_prev_strict_recover_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","previous_response_id":"resp_turn_prev_strict_recover_1","input":[{"type":"input_text","text":"world"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_prev_strict_recover_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 严格亲和 Layer2 恢复结束超时") + } + + require.Equal(t, 2, dialer.DialCount(), "严格亲和链路命中 previous_response_not_found 应触发 Layer2 恢复重试") + + firstConn.mu.Lock() + firstWrites := append([]map[string]any(nil), firstConn.writes...) + firstConn.mu.Unlock() + require.Len(t, firstWrites, 2, "首连接应收到首轮请求和失败的续链请求") + require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists()) + + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) + secondConn.mu.Unlock() + require.Len(t, secondWrites, 1, "Layer2 恢复应仅重放一次") + secondWrite := requestToJSONString(secondWrites[0]) + require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "Layer2 恢复重放应移除 previous_response_id") + require.True(t, gjson.Get(secondWrite, "store").Exists(), "Layer2 恢复不应改变 store 标志") + require.False(t, gjson.Get(secondWrite, "store").Bool()) + require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "Layer2 恢复应重放完整 input 上下文") + require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) + require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoveryRemovesDuplicatePrevID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"first missing"}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 120, + Name: "openai-ingress-prev-recovery-once", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_prev_once_1", gjson.GetBytes(firstTurn, "response.id").String()) + + // duplicate previous_response_id: 恢复重试时应删除所有重复键,避免再次 previous_response_not_found。 + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_once_1","input":[],"previous_response_id":"resp_turn_prev_duplicate"}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_prev_once_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应只重试一次") + + firstConn.mu.Lock() + firstWrites := append([]map[string]any(nil), firstConn.writes...) + firstConn.mu.Unlock() + require.Len(t, firstWrites, 2) + require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists()) + + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) + secondConn.mu.Unlock() + require.Len(t, secondWrites, 1) + require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "重复键场景恢复重试后不应保留 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_RejectsMessageIDAsPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 119, + Name: "openai-ingress-prev-validation", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456"}`)) + cancelWrite() + require.NoError(t, err) + + select { + case serverErr := <-serverErrCh: + require.Error(t, serverErr) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Contains(t, closeErr.Reason(), "previous_response_id must be a response.id") + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } +} + +type openAIWSQueueDialer struct { + mu sync.Mutex + conns []openAIWSClientConn + dialCount int +} + +func (d *openAIWSQueueDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + defer d.mu.Unlock() + d.dialCount++ + if len(d.conns) == 0 { + return nil, 503, nil, errors.New("no test conn") + } + conn := d.conns[0] + if len(d.conns) > 1 { + d.conns = d.conns[1:] + } + return conn, 0, nil, nil +} + +func (d *openAIWSQueueDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSPreflightFailConn struct { + mu sync.Mutex + events [][]byte + pingFails bool + writeCount int + pingCount int +} + +func (c *openAIWSPreflightFailConn) WriteJSON(context.Context, any) error { + c.mu.Lock() + c.writeCount++ + c.mu.Unlock() + return nil +} + +func (c *openAIWSPreflightFailConn) ReadMessage(context.Context) ([]byte, error) { + c.mu.Lock() + defer c.mu.Unlock() + if len(c.events) == 0 { + return nil, io.EOF + } + event := c.events[0] + c.events = c.events[1:] + if len(c.events) == 0 { + c.pingFails = true + } + return event, nil +} + +func (c *openAIWSPreflightFailConn) Ping(context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + c.pingCount++ + if c.pingFails { + return errors.New("preflight ping failed") + } + return nil +} + +func (c *openAIWSPreflightFailConn) Close() error { + return nil +} + +func (c *openAIWSPreflightFailConn) WriteCount() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.writeCount +} + +func (c *openAIWSPreflightFailConn) PingCount() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.pingCount +} + +type openAIWSWriteFailAfterFirstTurnConn struct { + mu sync.Mutex + events [][]byte + failOnWrite bool +} + +func (c *openAIWSWriteFailAfterFirstTurnConn) WriteJSON(context.Context, any) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.failOnWrite { + return errors.New("write failed on stale conn") + } + return nil +} + +func (c *openAIWSWriteFailAfterFirstTurnConn) ReadMessage(context.Context) ([]byte, error) { + c.mu.Lock() + defer c.mu.Unlock() + if len(c.events) == 0 { + return nil, io.EOF + } + event := c.events[0] + c.events = c.events[1:] + if len(c.events) == 0 { + c.failOnWrite = true + } + return event, nil +} + +func (c *openAIWSWriteFailAfterFirstTurnConn) Ping(context.Context) error { + return nil +} + +func (c *openAIWSWriteFailAfterFirstTurnConn) Close() error { + return nil +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnectStillDrainsUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + // 多个上游事件:前几个为非 terminal 事件,最后一个为 terminal。 + // 第一个事件延迟 250ms 让客户端 RST 有时间传播,使 writeClientMessage 可靠失败。 + captureConn := &openAIWSCaptureConn{ + readDelays: []time.Duration{250 * time.Millisecond, 0, 0}, + events: [][]byte{ + []byte(`{"type":"response.created","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1"}}`), + []byte(`{"type":"response.output_item.added","response":{"id":"resp_ingress_disconnect"}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 115, + Name: "openai-ingress-client-disconnect", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "model_mapping": map[string]any{ + "custom-original-model": "gpt-5.1", + }, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + resultCh := make(chan *OpenAIForwardResult, 1) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + if turnErr == nil && result != nil { + resultCh <- result + } + }, + } + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`)) + cancelWrite() + require.NoError(t, err) + // 立即关闭客户端,模拟客户端在 relay 期间断连。 + require.NoError(t, clientConn.CloseNow(), "模拟 ingress 客户端提前断连") + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr, "客户端断连后应继续 drain 上游直到 terminal 或正常结束") + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + select { + case result := <-resultCh: + require.Equal(t, "resp_ingress_disconnect", result.RequestID) + require.Equal(t, 2, result.Usage.InputTokens) + require.Equal(t, 1, result.Usage.OutputTokens) + case <-time.After(2 * time.Second): + t.Fatal("未收到断连后的 turn 结果回调") + } +} diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go new file mode 100644 index 00000000..ff35cb01 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -0,0 +1,714 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "io" + "net" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestIsOpenAIWSClientDisconnectError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + {name: "nil", err: nil, want: false}, + {name: "io_eof", err: io.EOF, want: true}, + {name: "net_closed", err: net.ErrClosed, want: true}, + {name: "context_canceled", err: context.Canceled, want: true}, + {name: "ws_normal_closure", err: coderws.CloseError{Code: coderws.StatusNormalClosure}, want: true}, + {name: "ws_going_away", err: coderws.CloseError{Code: coderws.StatusGoingAway}, want: true}, + {name: "ws_no_status", err: coderws.CloseError{Code: coderws.StatusNoStatusRcvd}, want: true}, + {name: "ws_abnormal_1006", err: coderws.CloseError{Code: coderws.StatusAbnormalClosure}, want: true}, + {name: "ws_policy_violation", err: coderws.CloseError{Code: coderws.StatusPolicyViolation}, want: false}, + {name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true}, + {name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true}, + {name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, isOpenAIWSClientDisconnectError(tt.err)) + }) + } +} + +func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) { + t.Parallel() + + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(nil)) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(errors.New("plain error"))) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound( + wrapOpenAIWSIngressTurnError("read_upstream", errors.New("upstream read failed"), false), + )) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound( + wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), true), + )) + require.True(t, isOpenAIWSIngressPreviousResponseNotFound( + wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), false), + )) +} + +func TestOpenAIWSIngressPreviousResponseRecoveryEnabled(t *testing.T) { + t.Parallel() + + var nilService *OpenAIGatewayService + require.True(t, nilService.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil service should default to enabled") + + svcWithNilCfg := &OpenAIGatewayService{} + require.True(t, svcWithNilCfg.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil config should default to enabled") + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + } + require.False(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled(), "explicit config default should be false") + + svc.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true + require.True(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled()) +} + +func TestDropPreviousResponseIDFromRawPayload(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, removed, err := dropPreviousResponseIDFromRawPayload(nil) + require.NoError(t, err) + require.False(t, removed) + require.Empty(t, updated) + }) + + t.Run("payload_without_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.False(t, removed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("normal_delete_success", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("duplicate_keys_are_removed", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_a","input":[],"previous_response_id":"resp_b"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("nil_delete_fn_uses_default_delete_logic", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, nil) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("delete_error", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, func(_ []byte, _ string) ([]byte, error) { + return nil, errors.New("delete failed") + }) + require.Error(t, err) + require.False(t, removed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("malformed_json_is_still_best_effort_deleted", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_abc"`) + require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists()) + + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) +} + +func TestAlignStoreDisabledPreviousResponseID(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, changed, err := alignStoreDisabledPreviousResponseID(nil, "resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Empty(t, updated) + }) + + t.Run("empty_expected", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_old"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("missing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("already_aligned", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_target"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) + + t.Run("mismatch_rewrites_to_expected", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_old","input":[]}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) + + t.Run("duplicate_keys_rewrites_to_single_expected", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_old_1","input":[],"previous_response_id":"resp_old_2"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) +} + +func TestSetPreviousResponseIDToRawPayload(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, err := setPreviousResponseIDToRawPayload(nil, "resp_target") + require.NoError(t, err) + require.Empty(t, updated) + }) + + t.Run("empty_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "") + require.NoError(t, err) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("set_previous_response_id_when_missing", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "resp_target") + require.NoError(t, err) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String()) + }) + + t.Run("overwrite_existing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_old"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "resp_new") + require.NoError(t, err) + require.Equal(t, "resp_new", gjson.GetBytes(updated, "previous_response_id").String()) + }) +} + +func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + storeDisabled bool + turn int + hasFunctionCallOutput bool + currentPreviousResponse string + expectedPrevious string + want bool + }{ + { + name: "infer_when_all_conditions_match", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: true, + }, + { + name: "skip_when_store_enabled", + storeDisabled: false, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_on_first_turn", + storeDisabled: true, + turn: 1, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_without_function_call_output", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: false, + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_when_request_already_has_previous_response_id", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + currentPreviousResponse: "resp_client", + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_when_last_turn_response_id_missing", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "", + want: false, + }, + { + name: "trim_whitespace_before_judgement", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: " resp_2 ", + want: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := shouldInferIngressFunctionCallOutputPreviousResponseID( + tt.storeDisabled, + tt.turn, + tt.hasFunctionCallOutput, + tt.currentPreviousResponse, + tt.expectedPrevious, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func TestOpenAIWSInputIsPrefixExtended(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + previous []byte + current []byte + want bool + expectErr bool + }{ + { + name: "both_missing_input", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_1"}`), + want: true, + }, + { + name: "previous_missing_current_empty_array", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`), + want: true, + }, + { + name: "previous_missing_current_non_empty_array", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"input_text","text":"hello"}]}`), + want: false, + }, + { + name: "array_prefix_match", + previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`), + current: []byte(`{"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]}`), + want: true, + }, + { + name: "array_prefix_mismatch", + previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`), + current: []byte(`{"input":[{"type":"input_text","text":"different"}]}`), + want: false, + }, + { + name: "current_shorter_than_previous", + previous: []byte(`{"input":[{"type":"input_text","text":"a"},{"type":"input_text","text":"b"}]}`), + current: []byte(`{"input":[{"type":"input_text","text":"a"}]}`), + want: false, + }, + { + name: "previous_has_input_current_missing", + previous: []byte(`{"input":[{"type":"input_text","text":"a"}]}`), + current: []byte(`{"model":"gpt-5.1"}`), + want: false, + }, + { + name: "input_string_treated_as_single_item", + previous: []byte(`{"input":"hello"}`), + current: []byte(`{"input":"hello"}`), + want: true, + }, + { + name: "current_invalid_input_json", + previous: []byte(`{"input":[]}`), + current: []byte(`{"input":[}`), + expectErr: true, + }, + { + name: "invalid_input_json", + previous: []byte(`{"input":[}`), + current: []byte(`{"input":[]}`), + expectErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := openAIWSInputIsPrefixExtended(tt.previous, tt.current) + if tt.expectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestNormalizeOpenAIWSJSONForCompare(t *testing.T) { + t.Parallel() + + normalized, err := normalizeOpenAIWSJSONForCompare([]byte(`{"b":2,"a":1}`)) + require.NoError(t, err) + require.Equal(t, `{"a":1,"b":2}`, string(normalized)) + + _, err = normalizeOpenAIWSJSONForCompare([]byte(" ")) + require.Error(t, err) + + _, err = normalizeOpenAIWSJSONForCompare([]byte(`{"a":`)) + require.Error(t, err) +} + +func TestNormalizeOpenAIWSJSONForCompareOrRaw(t *testing.T) { + t.Parallel() + + require.Equal(t, `{"a":1,"b":2}`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"b":2,"a":1}`)))) + require.Equal(t, `{"a":`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"a":`)))) +} + +func TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(t *testing.T) { + t.Parallel() + + normalized, err := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID( + []byte(`{"model":"gpt-5.1","input":[1],"previous_response_id":"resp_x","metadata":{"b":2,"a":1}}`), + ) + require.NoError(t, err) + require.False(t, gjson.GetBytes(normalized, "input").Exists()) + require.False(t, gjson.GetBytes(normalized, "previous_response_id").Exists()) + require.Equal(t, float64(1), gjson.GetBytes(normalized, "metadata.a").Float()) + + _, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(nil) + require.Error(t, err) + + _, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID([]byte(`[]`)) + require.Error(t, err) +} + +func TestOpenAIWSExtractNormalizedInputSequence(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence(nil) + require.NoError(t, err) + require.False(t, exists) + require.Nil(t, items) + }) + + t.Run("input_missing", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"type":"response.create"}`)) + require.NoError(t, err) + require.False(t, exists) + require.Nil(t, items) + }) + + t.Run("input_array", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[{"type":"input_text","text":"hello"}]}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + }) + + t.Run("input_object", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":{"type":"input_text","text":"hello"}}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + }) + + t.Run("input_string", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":"hello"}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, `"hello"`, string(items[0])) + }) + + t.Run("input_number", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":42}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "42", string(items[0])) + }) + + t.Run("input_bool", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":true}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "true", string(items[0])) + }) + + t.Run("input_null", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":null}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "null", string(items[0])) + }) + + t.Run("input_invalid_array_json", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[}`)) + require.Error(t, err) + require.True(t, exists) + require.Nil(t, items) + }) +} + +func TestShouldKeepIngressPreviousResponseID(t *testing.T) { + t.Parallel() + + previousPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "input":[{"type":"input_text","text":"hello"}] + }`) + currentStrictPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"name":"tool_a","type":"function"}], + "previous_response_id":"resp_turn_1", + "input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}] + }`) + + t.Run("strict_incremental_keep", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false) + require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "strict_incremental_ok", reason) + }) + + t.Run("missing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_previous_response_id", reason) + }) + + t.Run("missing_last_turn_response_id", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_last_turn_response_id", reason) + }) + + t.Run("previous_response_id_mismatch", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "previous_response_id_mismatch", reason) + }) + + t.Run("missing_previous_turn_payload", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_previous_turn_payload", reason) + }) + + t.Run("non_input_changed", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1-mini", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "previous_response_id":"resp_turn_1", + "input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "non_input_changed", reason) + }) + + t.Run("delta_input_keeps_previous_response_id", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "previous_response_id":"resp_turn_1", + "input":[{"type":"input_text","text":"different"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "strict_incremental_ok", reason) + }) + + t.Run("function_call_output_keeps_previous_response_id", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "previous_response_id":"resp_external", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true) + require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "has_function_call_output", reason) + }) + + t.Run("non_input_compare_error", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false) + require.Error(t, err) + require.False(t, keep) + require.Equal(t, "non_input_compare_error", reason) + }) + + t.Run("current_payload_compare_error", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false) + require.Error(t, err) + require.False(t, keep) + require.Equal(t, "non_input_compare_error", reason) + }) +} + +func TestBuildOpenAIWSReplayInputSequence(t *testing.T) { + t.Parallel() + + lastFull := []json.RawMessage{ + json.RawMessage(`{"type":"input_text","text":"hello"}`), + } + + t.Run("no_previous_response_id_use_current", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + lastFull, + true, + []byte(`{"input":[{"type":"input_text","text":"new"}]}`), + false, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "new", gjson.GetBytes(items[0], "text").String()) + }) + + t.Run("previous_response_id_delta_append", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + lastFull, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 2) + require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String()) + require.Equal(t, "world", gjson.GetBytes(items[1], "text").String()) + }) + + t.Run("previous_response_id_full_input_replace", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + lastFull, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 2) + require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String()) + require.Equal(t, "world", gjson.GetBytes(items[1], "text").String()) + }) +} + +func TestSetOpenAIWSPayloadInputSequence(t *testing.T) { + t.Parallel() + + t.Run("set_items", func(t *testing.T) { + original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`) + items := []json.RawMessage{ + json.RawMessage(`{"type":"input_text","text":"hello"}`), + json.RawMessage(`{"type":"input_text","text":"world"}`), + } + updated, err := setOpenAIWSPayloadInputSequence(original, items, true) + require.NoError(t, err) + require.Equal(t, "hello", gjson.GetBytes(updated, "input.0.text").String()) + require.Equal(t, "world", gjson.GetBytes(updated, "input.1.text").String()) + }) + + t.Run("preserve_empty_array_not_null", func(t *testing.T) { + original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`) + updated, err := setOpenAIWSPayloadInputSequence(original, nil, true) + require.NoError(t, err) + require.True(t, gjson.GetBytes(updated, "input").IsArray()) + require.Len(t, gjson.GetBytes(updated, "input").Array(), 0) + require.False(t, gjson.GetBytes(updated, "input").Type == gjson.Null) + }) +} + +func TestCloneOpenAIWSRawMessages(t *testing.T) { + t.Parallel() + + t.Run("nil_slice", func(t *testing.T) { + cloned := cloneOpenAIWSRawMessages(nil) + require.Nil(t, cloned) + }) + + t.Run("empty_slice", func(t *testing.T) { + items := make([]json.RawMessage, 0) + cloned := cloneOpenAIWSRawMessages(items) + require.NotNil(t, cloned) + require.Len(t, cloned, 0) + }) +} diff --git a/backend/internal/service/openai_ws_forwarder_retry_payload_test.go b/backend/internal/service/openai_ws_forwarder_retry_payload_test.go new file mode 100644 index 00000000..0ea7e1c7 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_retry_payload_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey(t *testing.T) { + payload := map[string]any{ + "model": "gpt-5.3-codex", + "prompt_cache_key": "pcache_123", + "include": []any{"reasoning.encrypted_content"}, + "text": map[string]any{ + "verbosity": "low", + }, + "tools": []any{map[string]any{"type": "function"}}, + } + + strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 3) + require.Equal(t, "trim_optional_fields", strategy) + require.Contains(t, removed, "include") + require.NotContains(t, removed, "prompt_cache_key") + require.Equal(t, "pcache_123", payload["prompt_cache_key"]) + require.NotContains(t, payload, "include") + require.Contains(t, payload, "text") +} + +func TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields(t *testing.T) { + payload := map[string]any{ + "prompt_cache_key": "pcache_456", + "instructions": "long instructions", + "tools": []any{map[string]any{"type": "function"}}, + "parallel_tool_calls": true, + "tool_choice": "auto", + "include": []any{"reasoning.encrypted_content"}, + "text": map[string]any{"verbosity": "high"}, + } + + strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 6) + require.Equal(t, "trim_optional_fields", strategy) + require.Contains(t, removed, "include") + require.NotContains(t, removed, "prompt_cache_key") + require.Equal(t, "pcache_456", payload["prompt_cache_key"]) + require.Contains(t, payload, "instructions") + require.Contains(t, payload, "tools") + require.Contains(t, payload, "tool_choice") + require.Contains(t, payload, "parallel_tool_calls") + require.Contains(t, payload, "text") +} diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go new file mode 100644 index 00000000..592801f6 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -0,0 +1,1306 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) { + gin.SetMode(gin.TestMode) + + type receivedPayload struct { + Type string + PreviousResponseID string + StreamExists bool + Stream bool + } + receivedCh := make(chan receivedPayload, 1) + + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + requestJSON := requestToJSONString(request) + receivedCh <- receivedPayload{ + Type: strings.TrimSpace(gjson.Get(requestJSON, "type").String()), + PreviousResponseID: strings.TrimSpace(gjson.Get(requestJSON, "previous_response_id").String()), + StreamExists: gjson.Get(requestJSON, "stream").Exists(), + Stream: gjson.Get(requestJSON, "stream").Bool(), + } + + if err := conn.WriteJSON(map[string]any{ + "type": "response.created", + "response": map[string]any{ + "id": "resp_new_1", + "model": "gpt-5.1", + }, + }); err != nil { + t.Errorf("write response.created failed: %v", err) + return + } + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_new_1", + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 12, + "output_tokens": 7, + "input_tokens_details": map[string]any{ + "cached_tokens": 3, + }, + }, + }, + }); err != nil { + t.Errorf("write response.completed failed: %v", err) + return + } + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + groupID := int64(1001) + c.Set("api_key", &APIKey{GroupID: &groupID}) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cache := &stubGatewayCache{} + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: cache, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 9, + Name: "openai-ws", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_1","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) + require.Equal(t, "resp_new_1", result.RequestID) + require.True(t, result.OpenAIWSMode) + require.False(t, gjson.GetBytes(upstream.lastBody, "model").Exists(), "WSv2 成功时不应回落 HTTP 上游") + + received := <-receivedCh + require.Equal(t, "response.create", received.Type) + require.Equal(t, "resp_prev_1", received.PreviousResponseID) + require.True(t, received.StreamExists, "WS 请求应携带 stream 字段") + require.False(t, received.Stream, "应保持客户端 stream=false 的原始语义") + + store := svc.getOpenAIWSStateStore() + mappedAccountID, getErr := store.GetResponseAccount(context.Background(), groupID, "resp_new_1") + require.NoError(t, getErr) + require.Equal(t, account.ID, mappedAccountID) + connID, ok := store.GetResponseConn("resp_new_1") + require.True(t, ok) + require.NotEmpty(t, connID) + + responseBody := rec.Body.Bytes() + require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String()) +} + +func requestToJSONString(payload map[string]any) string { + if len(payload) == 0 { + return "{}" + } + b, err := json.Marshal(payload) + if err != nil { + return "{}" + } + return string(b) +} + +func TestLogOpenAIWSBindResponseAccountWarn(t *testing.T) { + require.NotPanics(t, func() { + logOpenAIWSBindResponseAccountWarn(1, 2, "resp_ok", nil) + }) + require.NotPanics(t, func() { + logOpenAIWSBindResponseAccountWarn(1, 2, "resp_err", errors.New("bind failed")) + }) +} + +func TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + groupID := int64(3001) + c.Set("api_key", &APIKey{GroupID: &groupID}) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_model_tool_1","model":"gpt-5.1","tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}],"usage":{"input_tokens":2,"output_tokens":1}},"tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}]}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 1301, + Name: "openai-rewrite", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "model_mapping": map[string]any{ + "custom-original-model": "gpt-5.1", + }, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"custom-original-model","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_model_tool_1", result.RequestID) + require.Equal(t, "custom-original-model", gjson.GetBytes(rec.Body.Bytes(), "model").String(), "响应模型应回写为原始请求模型") + require.Equal(t, "edit", gjson.GetBytes(rec.Body.Bytes(), "tool_calls.0.function.name").String(), "工具名称应被修正为 OpenCode 规范") +} + +func TestOpenAIWSPayloadString_OnlyAcceptsStringValues(t *testing.T) { + payload := map[string]any{ + "type": nil, + "model": 123, + "prompt_cache_key": " cache-key ", + "previous_response_id": []byte(" resp_1 "), + } + + require.Equal(t, "", openAIWSPayloadString(payload, "type")) + require.Equal(t, "", openAIWSPayloadString(payload, "model")) + require.Equal(t, "cache-key", openAIWSPayloadString(payload, "prompt_cache_key")) + require.Equal(t, "resp_1", openAIWSPayloadString(payload, "previous_response_id")) +} + +func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) { + gin.SetMode(gin.TestMode) + + var upgradeCount atomic.Int64 + var sequence atomic.Int64 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgradeCount.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + for { + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + return + } + idx := sequence.Add(1) + responseID := "resp_reuse_" + strconv.FormatInt(idx, 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.created", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + }, + }); err != nil { + return + } + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 2, + "output_tokens": 1, + }, + }, + }); err != nil { + return + } + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + account := &Account{ + ID: 19, + Name: "openai-ws", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + for i := 0; i < 2; i++ { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + groupID := int64(2001) + c.Set("api_key", &APIKey{GroupID: &groupID}) + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_reuse","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_")) + } + + require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链") + metrics := svc.SnapshotOpenAIWSPoolMetrics() + require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1)) + require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1)) +} + +func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + c.Request.Header.Set("session_id", "sess-oauth-1") + c.Request.Header.Set("conversation_id", "conv-oauth-1") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.AllowStoreRecovery = false + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_oauth_1","model":"gpt-5.1","usage":{"input_tokens":3,"output_tokens":2}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 29, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token-1", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"store":true,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_oauth_1", result.RequestID) + + require.NotNil(t, captureConn.lastWrite) + requestJSON := requestToJSONString(captureConn.lastWrite) + require.True(t, gjson.Get(requestJSON, "store").Exists(), "OAuth WSv2 应显式写入 store 字段") + require.False(t, gjson.Get(requestJSON, "store").Bool(), "默认策略应将 OAuth store 置为 false") + require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段") + require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true") + require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta")) + require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id")) + require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id")) +} + +func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_prompt_cache_key","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 31, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token-1", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":true,"prompt_cache_key":"pcache_123","input":[{"type":"input_text","text":"hi"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_prompt_cache_key", result.RequestID) + + require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id")) + require.Empty(t, captureDialer.lastHeaders.Get("conversation_id")) + require.NotNil(t, captureConn.lastWrite) + require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsockets = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 39, + Name: "openai-ws-v1", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1/responses", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_v1","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "ws v1") + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "WSv1") + require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求") +} + +func TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + + var connIndex atomic.Int64 + headersCh := make(chan http.Header, 4) + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + idx := connIndex.Add(1) + headersCh <- cloneHeader(r.Header) + + respHeader := http.Header{} + if idx == 1 { + respHeader.Set("x-codex-turn-state", "turn_state_first") + } + conn, err := upgrader.Upgrade(w, r, respHeader) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + responseID := "resp_turn_" + strconv.FormatInt(idx, 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 2, + "output_tokens": 1, + }, + }, + }); err != nil { + t.Errorf("write response.completed failed: %v", err) + return + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 49, + Name: "openai-turn-state", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + rec1 := httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session_turn_state") + c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_1") + result1, err := svc.Forward(context.Background(), c1, account, reqBody) + require.NoError(t, err) + require.NotNil(t, result1) + + sessionHash := svc.GenerateSessionHash(c1, reqBody) + store := svc.getOpenAIWSStateStore() + turnState, ok := store.GetSessionTurnState(0, sessionHash) + require.True(t, ok) + require.Equal(t, "turn_state_first", turnState) + + // 主动淘汰连接,模拟下一次请求发生重连。 + connID, hasConn := store.GetResponseConn(result1.RequestID) + require.True(t, hasConn) + svc.getOpenAIWSConnPool().evictConn(account.ID, connID) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session_turn_state") + c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_2") + result2, err := svc.Forward(context.Background(), c2, account, reqBody) + require.NoError(t, err) + require.NotNil(t, result2) + + firstHandshakeHeaders := <-headersCh + secondHandshakeHeaders := <-headersCh + require.Equal(t, "turn_meta_1", firstHandshakeHeaders.Get("X-Codex-Turn-Metadata")) + require.Equal(t, "turn_meta_2", secondHandshakeHeaders.Get("X-Codex-Turn-Metadata")) + require.Equal(t, "turn_state_first", secondHandshakeHeaders.Get("X-Codex-Turn-State")) +} + +func TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("session_id", "session-prewarm") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_prewarm_1","model":"gpt-5.1","usage":{"input_tokens":0,"output_tokens":0}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_main_1","model":"gpt-5.1","usage":{"input_tokens":4,"output_tokens":2}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 59, + Name: "openai-prewarm", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_main_1", result.RequestID) + + require.Len(t, captureConn.writes, 2, "开启 generate=false 预热后应发送两次 WS 请求") + firstWrite := requestToJSONString(captureConn.writes[0]) + secondWrite := requestToJSONString(captureConn.writes[1]) + require.True(t, gjson.Get(firstWrite, "generate").Exists()) + require.False(t, gjson.Get(firstWrite, "generate").Bool()) + require.False(t, gjson.Get(secondWrite, "generate").Exists()) +} + +func TestOpenAIGatewayService_PrewarmReadHonorsParentContext(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + svc := &OpenAIGatewayService{ + cfg: cfg, + toolCorrector: NewCodexToolCorrector(), + } + account := &Account{ + ID: 601, + Name: "openai-prewarm-timeout", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + } + conn := newOpenAIWSConn("prewarm_ctx_conn", account.ID, &openAIWSBlockingConn{ + readDelay: 200 * time.Millisecond, + }, nil) + lease := &openAIWSConnLease{ + accountID: account.ID, + conn: conn, + } + payload := map[string]any{ + "type": "response.create", + "model": "gpt-5.1", + } + + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + defer cancel() + start := time.Now() + err := svc.performOpenAIWSGeneratePrewarm( + ctx, + lease, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + payload, + "", + map[string]any{"model": "gpt-5.1"}, + account, + nil, + 0, + ) + elapsed := time.Since(start) + require.Error(t, err) + require.Contains(t, err.Error(), "prewarm_read_event") + require.Less(t, elapsed, 180*time.Millisecond, "预热读取应受父 context 取消控制,不应阻塞到 read_timeout") +} + +func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_meta_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_meta_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 69, + Name: "openai-turn-metadata", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + + rec1 := httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session-metadata-reuse") + c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_1") + result1, err := svc.Forward(context.Background(), c1, account, body) + require.NoError(t, err) + require.NotNil(t, result1) + require.Equal(t, "resp_meta_1", result1.RequestID) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session-metadata-reuse") + c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_2") + result2, err := svc.Forward(context.Background(), c2, account, body) + require.NoError(t, err) + require.NotNil(t, result2) + require.Equal(t, "resp_meta_2", result2.RequestID) + + require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接") + require.Len(t, captureConn.writes, 2) + + firstWrite := requestToJSONString(captureConn.writes[0]) + secondWrite := requestToJSONString(captureConn.writes[1]) + require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String()) + require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String()) +} + +func TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation(t *testing.T) { + gin.SetMode(gin.TestMode) + + var upgradeCount atomic.Int64 + var sequence atomic.Int64 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgradeCount.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + for { + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + return + } + responseID := "resp_store_false_" + strconv.FormatInt(sequence.Add(1), 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + }, + }, + }); err != nil { + return + } + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4 + cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 79, + Name: "openai-store-false", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + + rec1 := httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session_store_false_a") + result1, err := svc.Forward(context.Background(), c1, account, body) + require.NoError(t, err) + require.NotNil(t, result1) + require.Equal(t, int64(1), upgradeCount.Load()) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session_store_false_a") + result2, err := svc.Forward(context.Background(), c2, account, body) + require.NoError(t, err) + require.NotNil(t, result2) + require.Equal(t, int64(1), upgradeCount.Load(), "同一 session(store=false) 应复用同一 WS 连接") + + rec3 := httptest.NewRecorder() + c3, _ := gin.CreateTestContext(rec3) + c3.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c3.Request.Header.Set("session_id", "session_store_false_b") + result3, err := svc.Forward(context.Background(), c3, account, body) + require.NoError(t, err) + require.NotNil(t, result3) + require.Equal(t, int64(2), upgradeCount.Load(), "不同 session(store=false) 应隔离连接,避免续链状态互相覆盖") +} + +func TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse(t *testing.T) { + gin.SetMode(gin.TestMode) + + var upgradeCount atomic.Int64 + var sequence atomic.Int64 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgradeCount.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + for { + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + return + } + responseID := "resp_store_false_reuse_" + strconv.FormatInt(sequence.Add(1), 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + }, + }, + }); err != nil { + return + } + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 80, + Name: "openai-store-false-reuse", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + + rec1 := httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session_store_false_reuse_a") + result1, err := svc.Forward(context.Background(), c1, account, body) + require.NoError(t, err) + require.NotNil(t, result1) + require.Equal(t, int64(1), upgradeCount.Load()) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session_store_false_reuse_b") + result2, err := svc.Forward(context.Background(), c2, account, body) + require.NoError(t, err) + require.NotNil(t, result2) + require.Equal(t, int64(1), upgradeCount.Load(), "关闭强制新连后,不同 session(store=false) 可复用连接") +} + +func TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 1 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + readDelays: []time.Duration{ + 700 * time.Millisecond, + 700 * time.Millisecond, + }, + events: [][]byte{ + []byte(`{"type":"response.created","response":{"id":"resp_timeout_ok","model":"gpt-5.1"}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_timeout_ok","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 81, + Name: "openai-read-timeout", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_timeout_ok", result.RequestID) + require.Nil(t, upstream.lastReq, "每次 Read 都应独立应用超时;总时长超过 read_timeout 不应误回退 HTTP") +} + +type openAIWSCaptureDialer struct { + mu sync.Mutex + conn *openAIWSCaptureConn + lastHeaders http.Header + handshake http.Header + dialCount int +} + +func (d *openAIWSCaptureDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = proxyURL + d.mu.Lock() + d.lastHeaders = cloneHeader(headers) + d.dialCount++ + respHeaders := cloneHeader(d.handshake) + d.mu.Unlock() + return d.conn, 0, respHeaders, nil +} + +func (d *openAIWSCaptureDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSCaptureConn struct { + mu sync.Mutex + readDelays []time.Duration + events [][]byte + lastWrite map[string]any + writes []map[string]any + closed bool +} + +func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return errOpenAIWSConnClosed + } + switch payload := value.(type) { + case map[string]any: + c.lastWrite = cloneMapStringAny(payload) + c.writes = append(c.writes, cloneMapStringAny(payload)) + case json.RawMessage: + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err == nil { + c.lastWrite = cloneMapStringAny(parsed) + c.writes = append(c.writes, cloneMapStringAny(parsed)) + } + case []byte: + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err == nil { + c.lastWrite = cloneMapStringAny(parsed) + c.writes = append(c.writes, cloneMapStringAny(parsed)) + } + } + return nil +} + +func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) { + if ctx == nil { + ctx = context.Background() + } + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, errOpenAIWSConnClosed + } + if len(c.events) == 0 { + c.mu.Unlock() + return nil, io.EOF + } + delay := time.Duration(0) + if len(c.readDelays) > 0 { + delay = c.readDelays[0] + c.readDelays = c.readDelays[1:] + } + event := c.events[0] + c.events = c.events[1:] + c.mu.Unlock() + if delay > 0 { + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + } + } + return event, nil +} + +func (c *openAIWSCaptureConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSCaptureConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func cloneMapStringAny(src map[string]any) map[string]any { + if src == nil { + return nil + } + dst := make(map[string]any, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} diff --git a/backend/internal/service/openai_ws_pool.go b/backend/internal/service/openai_ws_pool.go new file mode 100644 index 00000000..db6a96a7 --- /dev/null +++ b/backend/internal/service/openai_ws_pool.go @@ -0,0 +1,1706 @@ +package service + +import ( + "context" + "errors" + "fmt" + "math" + "net/http" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "golang.org/x/sync/errgroup" +) + +const ( + openAIWSConnMaxAge = 60 * time.Minute + openAIWSConnHealthCheckIdle = 90 * time.Second + openAIWSConnHealthCheckTO = 2 * time.Second + openAIWSConnPrewarmExtraDelay = 2 * time.Second + openAIWSAcquireCleanupInterval = 3 * time.Second + openAIWSBackgroundPingInterval = 30 * time.Second + openAIWSBackgroundSweepTicker = 30 * time.Second + + openAIWSPrewarmFailureWindow = 30 * time.Second + openAIWSPrewarmFailureSuppress = 2 +) + +var ( + errOpenAIWSConnClosed = errors.New("openai ws connection closed") + errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full") + errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable") +) + +type openAIWSDialError struct { + StatusCode int + ResponseHeaders http.Header + Err error +} + +func (e *openAIWSDialError) Error() string { + if e == nil { + return "" + } + if e.StatusCode > 0 { + return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err) + } + return fmt.Sprintf("openai ws dial failed: %v", e.Err) +} + +func (e *openAIWSDialError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +type openAIWSAcquireRequest struct { + Account *Account + WSURL string + Headers http.Header + ProxyURL string + PreferredConnID string + // ForceNewConn: 强制本次获取新连接(避免复用导致连接内续链状态互相污染)。 + ForceNewConn bool + // ForcePreferredConn: 强制本次只使用 PreferredConnID,禁止漂移到其它连接。 + ForcePreferredConn bool +} + +type openAIWSConnLease struct { + pool *openAIWSConnPool + accountID int64 + conn *openAIWSConn + queueWait time.Duration + connPick time.Duration + reused bool + released atomic.Bool +} + +func (l *openAIWSConnLease) activeConn() (*openAIWSConn, error) { + if l == nil || l.conn == nil { + return nil, errOpenAIWSConnClosed + } + if l.released.Load() { + return nil, errOpenAIWSConnClosed + } + return l.conn, nil +} + +func (l *openAIWSConnLease) ConnID() string { + if l == nil || l.conn == nil { + return "" + } + return l.conn.id +} + +func (l *openAIWSConnLease) QueueWaitDuration() time.Duration { + if l == nil { + return 0 + } + return l.queueWait +} + +func (l *openAIWSConnLease) ConnPickDuration() time.Duration { + if l == nil { + return 0 + } + return l.connPick +} + +func (l *openAIWSConnLease) Reused() bool { + if l == nil { + return false + } + return l.reused +} + +func (l *openAIWSConnLease) HandshakeHeader(name string) string { + if l == nil || l.conn == nil { + return "" + } + return l.conn.handshakeHeader(name) +} + +func (l *openAIWSConnLease) IsPrewarmed() bool { + if l == nil || l.conn == nil { + return false + } + return l.conn.isPrewarmed() +} + +func (l *openAIWSConnLease) MarkPrewarmed() { + if l == nil || l.conn == nil { + return + } + l.conn.markPrewarmed() +} + +func (l *openAIWSConnLease) WriteJSON(value any, timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return conn.writeJSONWithTimeout(context.Background(), value, timeout) +} + +func (l *openAIWSConnLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return conn.writeJSONWithTimeout(ctx, value, timeout) +} + +func (l *openAIWSConnLease) WriteJSONContext(ctx context.Context, value any) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return conn.writeJSON(value, ctx) +} + +func (l *openAIWSConnLease) ReadMessage(timeout time.Duration) ([]byte, error) { + conn, err := l.activeConn() + if err != nil { + return nil, err + } + return conn.readMessageWithTimeout(timeout) +} + +func (l *openAIWSConnLease) ReadMessageContext(ctx context.Context) ([]byte, error) { + conn, err := l.activeConn() + if err != nil { + return nil, err + } + return conn.readMessage(ctx) +} + +func (l *openAIWSConnLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) { + conn, err := l.activeConn() + if err != nil { + return nil, err + } + return conn.readMessageWithContextTimeout(ctx, timeout) +} + +func (l *openAIWSConnLease) PingWithTimeout(timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return conn.pingWithTimeout(timeout) +} + +func (l *openAIWSConnLease) MarkBroken() { + if l == nil || l.pool == nil || l.conn == nil || l.released.Load() { + return + } + l.pool.evictConn(l.accountID, l.conn.id) +} + +func (l *openAIWSConnLease) Release() { + if l == nil || l.conn == nil { + return + } + if !l.released.CompareAndSwap(false, true) { + return + } + l.conn.release() +} + +type openAIWSConn struct { + id string + ws openAIWSClientConn + + handshakeHeaders http.Header + + leaseCh chan struct{} + closedCh chan struct{} + closeOnce sync.Once + + readMu sync.Mutex + writeMu sync.Mutex + + waiters atomic.Int32 + createdAtNano atomic.Int64 + lastUsedNano atomic.Int64 + prewarmed atomic.Bool +} + +func newOpenAIWSConn(id string, _ int64, ws openAIWSClientConn, handshakeHeaders http.Header) *openAIWSConn { + now := time.Now() + conn := &openAIWSConn{ + id: id, + ws: ws, + handshakeHeaders: cloneHeader(handshakeHeaders), + leaseCh: make(chan struct{}, 1), + closedCh: make(chan struct{}), + } + conn.leaseCh <- struct{}{} + conn.createdAtNano.Store(now.UnixNano()) + conn.lastUsedNano.Store(now.UnixNano()) + return conn +} + +func (c *openAIWSConn) tryAcquire() bool { + if c == nil { + return false + } + select { + case <-c.closedCh: + return false + default: + } + select { + case <-c.leaseCh: + select { + case <-c.closedCh: + c.release() + return false + default: + } + return true + default: + return false + } +} + +func (c *openAIWSConn) acquire(ctx context.Context) error { + if c == nil { + return errOpenAIWSConnClosed + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.closedCh: + return errOpenAIWSConnClosed + case <-c.leaseCh: + select { + case <-c.closedCh: + c.release() + return errOpenAIWSConnClosed + default: + } + return nil + } + } +} + +func (c *openAIWSConn) release() { + if c == nil { + return + } + select { + case c.leaseCh <- struct{}{}: + default: + } + c.touch() +} + +func (c *openAIWSConn) close() { + if c == nil { + return + } + c.closeOnce.Do(func() { + close(c.closedCh) + if c.ws != nil { + _ = c.ws.Close() + } + select { + case c.leaseCh <- struct{}{}: + default: + } + }) +} + +func (c *openAIWSConn) writeJSONWithTimeout(parent context.Context, value any, timeout time.Duration) error { + if c == nil { + return errOpenAIWSConnClosed + } + select { + case <-c.closedCh: + return errOpenAIWSConnClosed + default: + } + + writeCtx := parent + if writeCtx == nil { + writeCtx = context.Background() + } + if timeout <= 0 { + return c.writeJSON(value, writeCtx) + } + var cancel context.CancelFunc + writeCtx, cancel = context.WithTimeout(writeCtx, timeout) + defer cancel() + return c.writeJSON(value, writeCtx) +} + +func (c *openAIWSConn) writeJSON(value any, writeCtx context.Context) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + if c.ws == nil { + return errOpenAIWSConnClosed + } + if writeCtx == nil { + writeCtx = context.Background() + } + if err := c.ws.WriteJSON(writeCtx, value); err != nil { + return err + } + c.touch() + return nil +} + +func (c *openAIWSConn) readMessageWithTimeout(timeout time.Duration) ([]byte, error) { + return c.readMessageWithContextTimeout(context.Background(), timeout) +} + +func (c *openAIWSConn) readMessageWithContextTimeout(parent context.Context, timeout time.Duration) ([]byte, error) { + if c == nil { + return nil, errOpenAIWSConnClosed + } + select { + case <-c.closedCh: + return nil, errOpenAIWSConnClosed + default: + } + + if parent == nil { + parent = context.Background() + } + if timeout <= 0 { + return c.readMessage(parent) + } + readCtx, cancel := context.WithTimeout(parent, timeout) + defer cancel() + return c.readMessage(readCtx) +} + +func (c *openAIWSConn) readMessage(readCtx context.Context) ([]byte, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + if c.ws == nil { + return nil, errOpenAIWSConnClosed + } + if readCtx == nil { + readCtx = context.Background() + } + payload, err := c.ws.ReadMessage(readCtx) + if err != nil { + return nil, err + } + c.touch() + return payload, nil +} + +func (c *openAIWSConn) pingWithTimeout(timeout time.Duration) error { + if c == nil { + return errOpenAIWSConnClosed + } + select { + case <-c.closedCh: + return errOpenAIWSConnClosed + default: + } + + c.writeMu.Lock() + defer c.writeMu.Unlock() + if c.ws == nil { + return errOpenAIWSConnClosed + } + if timeout <= 0 { + timeout = openAIWSConnHealthCheckTO + } + pingCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + if err := c.ws.Ping(pingCtx); err != nil { + return err + } + return nil +} + +func (c *openAIWSConn) touch() { + if c == nil { + return + } + c.lastUsedNano.Store(time.Now().UnixNano()) +} + +func (c *openAIWSConn) createdAt() time.Time { + if c == nil { + return time.Time{} + } + nano := c.createdAtNano.Load() + if nano <= 0 { + return time.Time{} + } + return time.Unix(0, nano) +} + +func (c *openAIWSConn) lastUsedAt() time.Time { + if c == nil { + return time.Time{} + } + nano := c.lastUsedNano.Load() + if nano <= 0 { + return time.Time{} + } + return time.Unix(0, nano) +} + +func (c *openAIWSConn) idleDuration(now time.Time) time.Duration { + if c == nil { + return 0 + } + last := c.lastUsedAt() + if last.IsZero() { + return 0 + } + return now.Sub(last) +} + +func (c *openAIWSConn) age(now time.Time) time.Duration { + if c == nil { + return 0 + } + created := c.createdAt() + if created.IsZero() { + return 0 + } + return now.Sub(created) +} + +func (c *openAIWSConn) isLeased() bool { + if c == nil { + return false + } + return len(c.leaseCh) == 0 +} + +func (c *openAIWSConn) handshakeHeader(name string) string { + if c == nil || c.handshakeHeaders == nil { + return "" + } + return strings.TrimSpace(c.handshakeHeaders.Get(strings.TrimSpace(name))) +} + +func (c *openAIWSConn) isPrewarmed() bool { + if c == nil { + return false + } + return c.prewarmed.Load() +} + +func (c *openAIWSConn) markPrewarmed() { + if c == nil { + return + } + c.prewarmed.Store(true) +} + +type openAIWSAccountPool struct { + mu sync.Mutex + conns map[string]*openAIWSConn + pinnedConns map[string]int + creating int + lastCleanupAt time.Time + lastAcquire *openAIWSAcquireRequest + prewarmActive bool + prewarmUntil time.Time + prewarmFails int + prewarmFailAt time.Time +} + +type OpenAIWSPoolMetricsSnapshot struct { + AcquireTotal int64 + AcquireReuseTotal int64 + AcquireCreateTotal int64 + AcquireQueueWaitTotal int64 + AcquireQueueWaitMsTotal int64 + ConnPickTotal int64 + ConnPickMsTotal int64 + ScaleUpTotal int64 + ScaleDownTotal int64 +} + +type openAIWSPoolMetrics struct { + acquireTotal atomic.Int64 + acquireReuseTotal atomic.Int64 + acquireCreateTotal atomic.Int64 + acquireQueueWaitTotal atomic.Int64 + acquireQueueWaitMs atomic.Int64 + connPickTotal atomic.Int64 + connPickMs atomic.Int64 + scaleUpTotal atomic.Int64 + scaleDownTotal atomic.Int64 +} + +type openAIWSConnPool struct { + cfg *config.Config + // 通过接口解耦底层 WS 客户端实现,默认使用 coder/websocket。 + clientDialer openAIWSClientDialer + + accounts sync.Map // key: int64(accountID), value: *openAIWSAccountPool + seq atomic.Uint64 + + metrics openAIWSPoolMetrics + + workerStopCh chan struct{} + workerWg sync.WaitGroup + closeOnce sync.Once +} + +func newOpenAIWSConnPool(cfg *config.Config) *openAIWSConnPool { + pool := &openAIWSConnPool{ + cfg: cfg, + clientDialer: newDefaultOpenAIWSClientDialer(), + workerStopCh: make(chan struct{}), + } + pool.startBackgroundWorkers() + return pool +} + +func (p *openAIWSConnPool) SnapshotMetrics() OpenAIWSPoolMetricsSnapshot { + if p == nil { + return OpenAIWSPoolMetricsSnapshot{} + } + return OpenAIWSPoolMetricsSnapshot{ + AcquireTotal: p.metrics.acquireTotal.Load(), + AcquireReuseTotal: p.metrics.acquireReuseTotal.Load(), + AcquireCreateTotal: p.metrics.acquireCreateTotal.Load(), + AcquireQueueWaitTotal: p.metrics.acquireQueueWaitTotal.Load(), + AcquireQueueWaitMsTotal: p.metrics.acquireQueueWaitMs.Load(), + ConnPickTotal: p.metrics.connPickTotal.Load(), + ConnPickMsTotal: p.metrics.connPickMs.Load(), + ScaleUpTotal: p.metrics.scaleUpTotal.Load(), + ScaleDownTotal: p.metrics.scaleDownTotal.Load(), + } +} + +func (p *openAIWSConnPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot { + if p == nil { + return OpenAIWSTransportMetricsSnapshot{} + } + if dialer, ok := p.clientDialer.(openAIWSTransportMetricsDialer); ok { + return dialer.SnapshotTransportMetrics() + } + return OpenAIWSTransportMetricsSnapshot{} +} + +func (p *openAIWSConnPool) setClientDialerForTest(dialer openAIWSClientDialer) { + if p == nil || dialer == nil { + return + } + p.clientDialer = dialer +} + +// Close 停止后台 worker 并关闭所有空闲连接,应在优雅关闭时调用。 +func (p *openAIWSConnPool) Close() { + if p == nil { + return + } + p.closeOnce.Do(func() { + if p.workerStopCh != nil { + close(p.workerStopCh) + } + p.workerWg.Wait() + // 遍历所有账户池,关闭全部空闲连接。 + p.accounts.Range(func(key, value any) bool { + ap, ok := value.(*openAIWSAccountPool) + if !ok || ap == nil { + return true + } + ap.mu.Lock() + for _, conn := range ap.conns { + if conn != nil && !conn.isLeased() { + conn.close() + } + } + ap.mu.Unlock() + return true + }) + }) +} + +func (p *openAIWSConnPool) startBackgroundWorkers() { + if p == nil || p.workerStopCh == nil { + return + } + p.workerWg.Add(2) + go func() { + defer p.workerWg.Done() + p.runBackgroundPingWorker() + }() + go func() { + defer p.workerWg.Done() + p.runBackgroundCleanupWorker() + }() +} + +type openAIWSIdlePingCandidate struct { + accountID int64 + conn *openAIWSConn +} + +func (p *openAIWSConnPool) runBackgroundPingWorker() { + if p == nil { + return + } + ticker := time.NewTicker(openAIWSBackgroundPingInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + p.runBackgroundPingSweep() + case <-p.workerStopCh: + return + } + } +} + +func (p *openAIWSConnPool) runBackgroundPingSweep() { + if p == nil { + return + } + candidates := p.snapshotIdleConnsForPing() + var g errgroup.Group + g.SetLimit(10) + for _, item := range candidates { + item := item + if item.conn == nil || item.conn.isLeased() || item.conn.waiters.Load() > 0 { + continue + } + g.Go(func() error { + if err := item.conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + p.evictConn(item.accountID, item.conn.id) + } + return nil + }) + } + _ = g.Wait() +} + +func (p *openAIWSConnPool) snapshotIdleConnsForPing() []openAIWSIdlePingCandidate { + if p == nil { + return nil + } + candidates := make([]openAIWSIdlePingCandidate, 0) + p.accounts.Range(func(key, value any) bool { + accountID, ok := key.(int64) + if !ok || accountID <= 0 { + return true + } + ap, ok := value.(*openAIWSAccountPool) + if !ok || ap == nil { + return true + } + ap.mu.Lock() + for _, conn := range ap.conns { + if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 { + continue + } + candidates = append(candidates, openAIWSIdlePingCandidate{ + accountID: accountID, + conn: conn, + }) + } + ap.mu.Unlock() + return true + }) + return candidates +} + +func (p *openAIWSConnPool) runBackgroundCleanupWorker() { + if p == nil { + return + } + ticker := time.NewTicker(openAIWSBackgroundSweepTicker) + defer ticker.Stop() + for { + select { + case <-ticker.C: + p.runBackgroundCleanupSweep(time.Now()) + case <-p.workerStopCh: + return + } + } +} + +func (p *openAIWSConnPool) runBackgroundCleanupSweep(now time.Time) { + if p == nil { + return + } + type cleanupResult struct { + evicted []*openAIWSConn + } + results := make([]cleanupResult, 0) + p.accounts.Range(func(_ any, value any) bool { + ap, ok := value.(*openAIWSAccountPool) + if !ok || ap == nil { + return true + } + maxConns := p.maxConnsHardCap() + ap.mu.Lock() + if ap.lastAcquire != nil && ap.lastAcquire.Account != nil { + maxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account) + } + evicted := p.cleanupAccountLocked(ap, now, maxConns) + ap.lastCleanupAt = now + ap.mu.Unlock() + if len(evicted) > 0 { + results = append(results, cleanupResult{evicted: evicted}) + } + return true + }) + for _, result := range results { + closeOpenAIWSConns(result.evicted) + } +} + +func (p *openAIWSConnPool) Acquire(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConnLease, error) { + if p != nil { + p.metrics.acquireTotal.Add(1) + } + return p.acquire(ctx, cloneOpenAIWSAcquireRequest(req), 0) +} + +func (p *openAIWSConnPool) acquire(ctx context.Context, req openAIWSAcquireRequest, retry int) (*openAIWSConnLease, error) { + if p == nil || req.Account == nil || req.Account.ID <= 0 { + return nil, errors.New("invalid ws acquire request") + } + if stringsTrim(req.WSURL) == "" { + return nil, errors.New("ws url is empty") + } + + accountID := req.Account.ID + effectiveMaxConns := p.effectiveMaxConnsByAccount(req.Account) + if effectiveMaxConns <= 0 { + return nil, errOpenAIWSConnQueueFull + } + var evicted []*openAIWSConn + ap := p.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = cloneOpenAIWSAcquireRequestPtr(&req) + now := time.Now() + if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval { + evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns) + ap.lastCleanupAt = now + } + pickStartedAt := time.Now() + allowReuse := !req.ForceNewConn + preferredConnID := stringsTrim(req.PreferredConnID) + forcePreferredConn := allowReuse && req.ForcePreferredConn + + if allowReuse { + if forcePreferredConn { + if preferredConnID == "" { + p.recordConnPickDuration(time.Since(pickStartedAt)) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSPreferredConnUnavailable + } + preferredConn, ok := ap.conns[preferredConnID] + if !ok || preferredConn == nil { + p.recordConnPickDuration(time.Since(pickStartedAt)) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSPreferredConnUnavailable + } + if preferredConn.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(preferredConn) { + if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + preferredConn.close() + p.evictConn(accountID, preferredConn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{ + pool: p, + accountID: accountID, + conn: preferredConn, + connPick: connPick, + reused: true, + } + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + if int(preferredConn.waiters.Load()) >= p.queueLimitPerConn() { + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnQueueFull + } + preferredConn.waiters.Add(1) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + defer preferredConn.waiters.Add(-1) + waitStart := time.Now() + p.metrics.acquireQueueWaitTotal.Add(1) + + if err := preferredConn.acquire(ctx); err != nil { + if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + if p.shouldHealthCheckConn(preferredConn) { + if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + preferredConn.release() + preferredConn.close() + p.evictConn(accountID, preferredConn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + + queueWait := time.Since(waitStart) + p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds()) + lease := &openAIWSConnLease{ + pool: p, + accountID: accountID, + conn: preferredConn, + queueWait: queueWait, + connPick: connPick, + reused: true, + } + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + + if preferredConnID != "" { + if conn, ok := ap.conns[preferredConnID]; ok && conn.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(conn) { + if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + conn.close() + p.evictConn(accountID, conn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + } + + best := p.pickLeastBusyConnLocked(ap, "") + if best != nil && best.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(best) { + if err := best.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + best.close() + p.evictConn(accountID, best.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: best, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + for _, conn := range ap.conns { + if conn == nil || conn == best { + continue + } + if conn.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(conn) { + if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + conn.close() + p.evictConn(accountID, conn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + } + } + + if req.ForceNewConn && len(ap.conns)+ap.creating >= effectiveMaxConns { + if idle := p.pickOldestIdleConnLocked(ap); idle != nil { + delete(ap.conns, idle.id) + evicted = append(evicted, idle) + p.metrics.scaleDownTotal.Add(1) + } + } + + if len(ap.conns)+ap.creating < effectiveMaxConns { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.creating++ + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + + conn, dialErr := p.dialConn(ctx, req) + + ap = p.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.creating-- + if dialErr != nil { + ap.prewarmFails++ + ap.prewarmFailAt = time.Now() + ap.mu.Unlock() + return nil, dialErr + } + ap.conns[conn.id] = conn + ap.prewarmFails = 0 + ap.prewarmFailAt = time.Time{} + ap.mu.Unlock() + p.metrics.acquireCreateTotal.Add(1) + + if !conn.tryAcquire() { + if err := conn.acquire(ctx); err != nil { + conn.close() + p.evictConn(accountID, conn.id) + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick} + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + + if req.ForceNewConn { + p.recordConnPickDuration(time.Since(pickStartedAt)) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnQueueFull + } + + target := p.pickLeastBusyConnLocked(ap, req.PreferredConnID) + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + if target == nil { + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnClosed + } + if int(target.waiters.Load()) >= p.queueLimitPerConn() { + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnQueueFull + } + target.waiters.Add(1) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + defer target.waiters.Add(-1) + waitStart := time.Now() + p.metrics.acquireQueueWaitTotal.Add(1) + + if err := target.acquire(ctx); err != nil { + if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + if p.shouldHealthCheckConn(target) { + if err := target.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + target.release() + target.close() + p.evictConn(accountID, target.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + + queueWait := time.Since(waitStart) + p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds()) + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: target, queueWait: queueWait, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil +} + +func (p *openAIWSConnPool) recordConnPickDuration(duration time.Duration) { + if p == nil { + return + } + if duration < 0 { + duration = 0 + } + p.metrics.connPickTotal.Add(1) + p.metrics.connPickMs.Add(duration.Milliseconds()) +} + +func (p *openAIWSConnPool) pickOldestIdleConnLocked(ap *openAIWSAccountPool) *openAIWSConn { + if ap == nil || len(ap.conns) == 0 { + return nil + } + var oldest *openAIWSConn + for _, conn := range ap.conns { + if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) { + continue + } + if oldest == nil || conn.lastUsedAt().Before(oldest.lastUsedAt()) { + oldest = conn + } + } + return oldest +} + +func (p *openAIWSConnPool) getOrCreateAccountPool(accountID int64) *openAIWSAccountPool { + if p == nil || accountID <= 0 { + return nil + } + if existing, ok := p.accounts.Load(accountID); ok { + if ap, typed := existing.(*openAIWSAccountPool); typed && ap != nil { + return ap + } + } + ap := &openAIWSAccountPool{ + conns: make(map[string]*openAIWSConn), + pinnedConns: make(map[string]int), + } + actual, _ := p.accounts.LoadOrStore(accountID, ap) + if typed, ok := actual.(*openAIWSAccountPool); ok && typed != nil { + return typed + } + return ap +} + +// ensureAccountPoolLocked 兼容旧调用。 +func (p *openAIWSConnPool) ensureAccountPoolLocked(accountID int64) *openAIWSAccountPool { + return p.getOrCreateAccountPool(accountID) +} + +func (p *openAIWSConnPool) getAccountPool(accountID int64) (*openAIWSAccountPool, bool) { + if p == nil || accountID <= 0 { + return nil, false + } + value, ok := p.accounts.Load(accountID) + if !ok || value == nil { + return nil, false + } + ap, typed := value.(*openAIWSAccountPool) + return ap, typed && ap != nil +} + +func (p *openAIWSConnPool) isConnPinnedLocked(ap *openAIWSAccountPool, connID string) bool { + if ap == nil || connID == "" || len(ap.pinnedConns) == 0 { + return false + } + return ap.pinnedConns[connID] > 0 +} + +func (p *openAIWSConnPool) cleanupAccountLocked(ap *openAIWSAccountPool, now time.Time, maxConns int) []*openAIWSConn { + if ap == nil { + return nil + } + maxAge := p.maxConnAge() + + evicted := make([]*openAIWSConn, 0) + for id, conn := range ap.conns { + if conn == nil { + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + continue + } + select { + case <-conn.closedCh: + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + evicted = append(evicted, conn) + continue + default: + } + if p.isConnPinnedLocked(ap, id) { + continue + } + if maxAge > 0 && !conn.isLeased() && conn.age(now) > maxAge { + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + evicted = append(evicted, conn) + } + } + + if maxConns <= 0 { + maxConns = p.maxConnsHardCap() + } + maxIdle := p.maxIdlePerAccount() + if maxIdle < 0 || maxIdle > maxConns { + maxIdle = maxConns + } + if maxIdle >= 0 && len(ap.conns) > maxIdle { + idleConns := make([]*openAIWSConn, 0, len(ap.conns)) + for id, conn := range ap.conns { + if conn == nil { + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + continue + } + // 有等待者的连接不能在清理阶段被淘汰,否则等待中的 acquire 会收到 closed 错误。 + if conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) { + continue + } + idleConns = append(idleConns, conn) + } + sort.SliceStable(idleConns, func(i, j int) bool { + return idleConns[i].lastUsedAt().Before(idleConns[j].lastUsedAt()) + }) + redundant := len(ap.conns) - maxIdle + if redundant > len(idleConns) { + redundant = len(idleConns) + } + for i := 0; i < redundant; i++ { + conn := idleConns[i] + delete(ap.conns, conn.id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, conn.id) + } + evicted = append(evicted, conn) + } + if redundant > 0 { + p.metrics.scaleDownTotal.Add(int64(redundant)) + } + } + + return evicted +} + +func (p *openAIWSConnPool) pickLeastBusyConnLocked(ap *openAIWSAccountPool, preferredConnID string) *openAIWSConn { + if ap == nil || len(ap.conns) == 0 { + return nil + } + preferredConnID = stringsTrim(preferredConnID) + if preferredConnID != "" { + if conn, ok := ap.conns[preferredConnID]; ok { + return conn + } + } + var best *openAIWSConn + var bestWaiters int32 + var bestLastUsed time.Time + for _, conn := range ap.conns { + if conn == nil { + continue + } + waiters := conn.waiters.Load() + lastUsed := conn.lastUsedAt() + if best == nil || + waiters < bestWaiters || + (waiters == bestWaiters && lastUsed.Before(bestLastUsed)) { + best = conn + bestWaiters = waiters + bestLastUsed = lastUsed + } + } + return best +} + +func accountPoolLoadLocked(ap *openAIWSAccountPool) (inflight int, waiters int) { + if ap == nil { + return 0, 0 + } + for _, conn := range ap.conns { + if conn == nil { + continue + } + if conn.isLeased() { + inflight++ + } + waiters += int(conn.waiters.Load()) + } + return inflight, waiters +} + +// AccountPoolLoad 返回指定账号连接池的并发与排队快照。 +func (p *openAIWSConnPool) AccountPoolLoad(accountID int64) (inflight int, waiters int, conns int) { + if p == nil || accountID <= 0 { + return 0, 0, 0 + } + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return 0, 0, 0 + } + ap.mu.Lock() + defer ap.mu.Unlock() + inflight, waiters = accountPoolLoadLocked(ap) + return inflight, waiters, len(ap.conns) +} + +func (p *openAIWSConnPool) ensureTargetIdleAsync(accountID int64) { + if p == nil || accountID <= 0 { + return + } + + var req openAIWSAcquireRequest + need := 0 + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return + } + ap.mu.Lock() + defer ap.mu.Unlock() + if ap.lastAcquire == nil { + return + } + if ap.prewarmActive { + return + } + now := time.Now() + if !ap.prewarmUntil.IsZero() && now.Before(ap.prewarmUntil) { + return + } + if p.shouldSuppressPrewarmLocked(ap, now) { + return + } + effectiveMaxConns := p.maxConnsHardCap() + if ap.lastAcquire != nil && ap.lastAcquire.Account != nil { + effectiveMaxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account) + } + target := p.targetConnCountLocked(ap, effectiveMaxConns) + current := len(ap.conns) + ap.creating + if current >= target { + return + } + need = target - current + if need <= 0 { + return + } + req = cloneOpenAIWSAcquireRequest(*ap.lastAcquire) + ap.prewarmActive = true + if cooldown := p.prewarmCooldown(); cooldown > 0 { + ap.prewarmUntil = now.Add(cooldown) + } + ap.creating += need + p.metrics.scaleUpTotal.Add(int64(need)) + + go p.prewarmConns(accountID, req, need) +} + +func (p *openAIWSConnPool) targetConnCountLocked(ap *openAIWSAccountPool, maxConns int) int { + if ap == nil { + return 0 + } + + if maxConns <= 0 { + return 0 + } + + minIdle := p.minIdlePerAccount() + if minIdle < 0 { + minIdle = 0 + } + if minIdle > maxConns { + minIdle = maxConns + } + + inflight, waiters := accountPoolLoadLocked(ap) + utilization := p.targetUtilization() + demand := inflight + waiters + if demand <= 0 { + return minIdle + } + + target := 1 + if demand > 1 { + target = int(math.Ceil(float64(demand) / utilization)) + } + if waiters > 0 && target < len(ap.conns)+1 { + target = len(ap.conns) + 1 + } + if target < minIdle { + target = minIdle + } + if target > maxConns { + target = maxConns + } + return target +} + +func (p *openAIWSConnPool) prewarmConns(accountID int64, req openAIWSAcquireRequest, total int) { + defer func() { + if ap, ok := p.getAccountPool(accountID); ok && ap != nil { + ap.mu.Lock() + ap.prewarmActive = false + ap.mu.Unlock() + } + }() + + for i := 0; i < total; i++ { + ctx, cancel := context.WithTimeout(context.Background(), p.dialTimeout()+openAIWSConnPrewarmExtraDelay) + conn, err := p.dialConn(ctx, req) + cancel() + + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + if conn != nil { + conn.close() + } + return + } + ap.mu.Lock() + if ap.creating > 0 { + ap.creating-- + } + if err != nil { + ap.prewarmFails++ + ap.prewarmFailAt = time.Now() + ap.mu.Unlock() + continue + } + if len(ap.conns) >= p.effectiveMaxConnsByAccount(req.Account) { + ap.mu.Unlock() + conn.close() + continue + } + ap.conns[conn.id] = conn + ap.prewarmFails = 0 + ap.prewarmFailAt = time.Time{} + ap.mu.Unlock() + } +} + +func (p *openAIWSConnPool) evictConn(accountID int64, connID string) { + if p == nil || accountID <= 0 || stringsTrim(connID) == "" { + return + } + var conn *openAIWSConn + ap, ok := p.getAccountPool(accountID) + if ok && ap != nil { + ap.mu.Lock() + if c, exists := ap.conns[connID]; exists { + conn = c + delete(ap.conns, connID) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, connID) + } + } + ap.mu.Unlock() + } + if conn != nil { + conn.close() + } +} + +func (p *openAIWSConnPool) PinConn(accountID int64, connID string) bool { + if p == nil || accountID <= 0 { + return false + } + connID = stringsTrim(connID) + if connID == "" { + return false + } + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + if _, exists := ap.conns[connID]; !exists { + return false + } + if ap.pinnedConns == nil { + ap.pinnedConns = make(map[string]int) + } + ap.pinnedConns[connID]++ + return true +} + +func (p *openAIWSConnPool) UnpinConn(accountID int64, connID string) { + if p == nil || accountID <= 0 { + return + } + connID = stringsTrim(connID) + if connID == "" { + return + } + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return + } + ap.mu.Lock() + defer ap.mu.Unlock() + if len(ap.pinnedConns) == 0 { + return + } + count := ap.pinnedConns[connID] + if count <= 1 { + delete(ap.pinnedConns, connID) + return + } + ap.pinnedConns[connID] = count - 1 +} + +func (p *openAIWSConnPool) dialConn(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConn, error) { + if p == nil || p.clientDialer == nil { + return nil, errors.New("openai ws client dialer is nil") + } + conn, status, handshakeHeaders, err := p.clientDialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL) + if err != nil { + return nil, &openAIWSDialError{ + StatusCode: status, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: err, + } + } + if conn == nil { + return nil, &openAIWSDialError{ + StatusCode: status, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: errors.New("openai ws dialer returned nil connection"), + } + } + id := p.nextConnID(req.Account.ID) + return newOpenAIWSConn(id, req.Account.ID, conn, handshakeHeaders), nil +} + +func (p *openAIWSConnPool) nextConnID(accountID int64) string { + seq := p.seq.Add(1) + buf := make([]byte, 0, 32) + buf = append(buf, "oa_ws_"...) + buf = strconv.AppendInt(buf, accountID, 10) + buf = append(buf, '_') + buf = strconv.AppendUint(buf, seq, 10) + return string(buf) +} + +func (p *openAIWSConnPool) shouldHealthCheckConn(conn *openAIWSConn) bool { + if conn == nil { + return false + } + return conn.idleDuration(time.Now()) >= openAIWSConnHealthCheckIdle +} + +func (p *openAIWSConnPool) maxConnsHardCap() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 { + return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount + } + return 8 +} + +func (p *openAIWSConnPool) dynamicMaxConnsEnabled() bool { + if p != nil && p.cfg != nil { + return p.cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled + } + return false +} + +func (p *openAIWSConnPool) modeRouterV2Enabled() bool { + if p != nil && p.cfg != nil { + return p.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled + } + return false +} + +func (p *openAIWSConnPool) maxConnsFactorByAccount(account *Account) float64 { + if p == nil || p.cfg == nil || account == nil { + return 1.0 + } + switch account.Type { + case AccountTypeOAuth: + if p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor > 0 { + return p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor + } + case AccountTypeAPIKey: + if p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor > 0 { + return p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor + } + } + return 1.0 +} + +func (p *openAIWSConnPool) effectiveMaxConnsByAccount(account *Account) int { + hardCap := p.maxConnsHardCap() + if hardCap <= 0 { + return 0 + } + if p.modeRouterV2Enabled() { + if account == nil { + return hardCap + } + if account.Concurrency <= 0 { + return 0 + } + return account.Concurrency + } + if account == nil || !p.dynamicMaxConnsEnabled() { + return hardCap + } + if account.Concurrency <= 0 { + // 0/-1 等“无限制”并发场景下,仍由全局硬上限兜底。 + return hardCap + } + factor := p.maxConnsFactorByAccount(account) + if factor <= 0 { + factor = 1.0 + } + effective := int(math.Ceil(float64(account.Concurrency) * factor)) + if effective < 1 { + effective = 1 + } + if effective > hardCap { + effective = hardCap + } + return effective +} + +func (p *openAIWSConnPool) minIdlePerAccount() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MinIdlePerAccount >= 0 { + return p.cfg.Gateway.OpenAIWS.MinIdlePerAccount + } + return 0 +} + +func (p *openAIWSConnPool) maxIdlePerAccount() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount >= 0 { + return p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount + } + return 4 +} + +func (p *openAIWSConnPool) maxConnAge() time.Duration { + return openAIWSConnMaxAge +} + +func (p *openAIWSConnPool) queueLimitPerConn() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.QueueLimitPerConn > 0 { + return p.cfg.Gateway.OpenAIWS.QueueLimitPerConn + } + return 256 +} + +func (p *openAIWSConnPool) targetUtilization() float64 { + if p != nil && p.cfg != nil { + ratio := p.cfg.Gateway.OpenAIWS.PoolTargetUtilization + if ratio > 0 && ratio <= 1 { + return ratio + } + } + return 0.7 +} + +func (p *openAIWSConnPool) prewarmCooldown() time.Duration { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS > 0 { + return time.Duration(p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS) * time.Millisecond + } + return 0 +} + +func (p *openAIWSConnPool) shouldSuppressPrewarmLocked(ap *openAIWSAccountPool, now time.Time) bool { + if ap == nil { + return true + } + if ap.prewarmFails <= 0 { + return false + } + if ap.prewarmFailAt.IsZero() { + ap.prewarmFails = 0 + return false + } + if now.Sub(ap.prewarmFailAt) > openAIWSPrewarmFailureWindow { + ap.prewarmFails = 0 + ap.prewarmFailAt = time.Time{} + return false + } + return ap.prewarmFails >= openAIWSPrewarmFailureSuppress +} + +func (p *openAIWSConnPool) dialTimeout() time.Duration { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { + return time.Duration(p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second + } + return 10 * time.Second +} + +func cloneOpenAIWSAcquireRequest(req openAIWSAcquireRequest) openAIWSAcquireRequest { + copied := req + copied.Headers = cloneHeader(req.Headers) + copied.WSURL = stringsTrim(req.WSURL) + copied.ProxyURL = stringsTrim(req.ProxyURL) + copied.PreferredConnID = stringsTrim(req.PreferredConnID) + return copied +} + +func cloneOpenAIWSAcquireRequestPtr(req *openAIWSAcquireRequest) *openAIWSAcquireRequest { + if req == nil { + return nil + } + copied := cloneOpenAIWSAcquireRequest(*req) + return &copied +} + +func cloneHeader(src http.Header) http.Header { + if src == nil { + return nil + } + dst := make(http.Header, len(src)) + for k, vals := range src { + if len(vals) == 0 { + dst[k] = nil + continue + } + copied := make([]string, len(vals)) + copy(copied, vals) + dst[k] = copied + } + return dst +} + +func closeOpenAIWSConns(conns []*openAIWSConn) { + if len(conns) == 0 { + return + } + for _, conn := range conns { + if conn == nil { + continue + } + conn.close() + } +} + +func stringsTrim(value string) string { + return strings.TrimSpace(value) +} diff --git a/backend/internal/service/openai_ws_pool_benchmark_test.go b/backend/internal/service/openai_ws_pool_benchmark_test.go new file mode 100644 index 00000000..bff74b62 --- /dev/null +++ b/backend/internal/service/openai_ws_pool_benchmark_test.go @@ -0,0 +1,58 @@ +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +func BenchmarkOpenAIWSPoolAcquire(b *testing.B) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(&openAIWSCountingDialer{}) + + account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + req := openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ctx := context.Background() + + lease, err := pool.Acquire(ctx, req) + if err != nil { + b.Fatalf("warm acquire failed: %v", err) + } + lease.Release() + + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + var ( + got *openAIWSConnLease + acquireErr error + ) + for retry := 0; retry < 3; retry++ { + got, acquireErr = pool.Acquire(ctx, req) + if acquireErr == nil { + break + } + if !errors.Is(acquireErr, errOpenAIWSConnClosed) { + break + } + } + if acquireErr != nil { + b.Fatalf("acquire failed: %v", acquireErr) + } + got.Release() + } + }) +} diff --git a/backend/internal/service/openai_ws_pool_test.go b/backend/internal/service/openai_ws_pool_test.go new file mode 100644 index 00000000..b2683ee0 --- /dev/null +++ b/backend/internal/service/openai_ws_pool_test.go @@ -0,0 +1,1709 @@ +package service + +import ( + "context" + "errors" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSConnPool_CleanupStaleAndTrimIdle(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + pool := newOpenAIWSConnPool(cfg) + + accountID := int64(10) + ap := pool.getOrCreateAccountPool(accountID) + + stale := newOpenAIWSConn("stale", accountID, nil, nil) + stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + + idleOld := newOpenAIWSConn("idle_old", accountID, nil, nil) + idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano()) + + idleNew := newOpenAIWSConn("idle_new", accountID, nil, nil) + idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano()) + + ap.conns[stale.id] = stale + ap.conns[idleOld.id] = idleOld + ap.conns[idleNew.id] = idleNew + + evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()) + closeOpenAIWSConns(evicted) + + require.Nil(t, ap.conns["stale"], "stale connection should be rotated") + require.Nil(t, ap.conns["idle_old"], "old idle should be trimmed by max_idle") + require.NotNil(t, ap.conns["idle_new"], "newer idle should be kept") +} + +func TestOpenAIWSConnPool_NextConnIDFormat(t *testing.T) { + pool := newOpenAIWSConnPool(&config.Config{}) + id1 := pool.nextConnID(42) + id2 := pool.nextConnID(42) + + require.True(t, strings.HasPrefix(id1, "oa_ws_42_")) + require.True(t, strings.HasPrefix(id2, "oa_ws_42_")) + require.NotEqual(t, id1, id2) + require.Equal(t, "oa_ws_42_1", id1) + require.Equal(t, "oa_ws_42_2", id2) +} + +func TestOpenAIWSConnPool_AcquireCleanupInterval(t *testing.T) { + require.Equal(t, 3*time.Second, openAIWSAcquireCleanupInterval) + require.Less(t, openAIWSAcquireCleanupInterval, openAIWSBackgroundSweepTicker) +} + +func TestOpenAIWSConnLease_WriteJSONAndGuards(t *testing.T) { + conn := newOpenAIWSConn("lease_write", 1, &openAIWSFakeConn{}, nil) + lease := &openAIWSConnLease{conn: conn} + require.NoError(t, lease.WriteJSON(map[string]any{"type": "response.create"}, 0)) + + var nilLease *openAIWSConnLease + err := nilLease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + + err = (&openAIWSConnLease{}).WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground(t *testing.T) { + probe := &openAIWSContextProbeConn{} + conn := newOpenAIWSConn("ctx_probe", 1, probe, nil) + require.NoError(t, conn.writeJSONWithTimeout(context.Background(), map[string]any{"type": "response.create"}, 0)) + require.NotNil(t, probe.lastWriteCtx) +} + +func TestOpenAIWSConnPool_TargetConnCountAdaptive(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 6 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.5 + + pool := newOpenAIWSConnPool(cfg) + ap := pool.getOrCreateAccountPool(88) + + conn1 := newOpenAIWSConn("c1", 88, nil, nil) + conn2 := newOpenAIWSConn("c2", 88, nil, nil) + require.True(t, conn1.tryAcquire()) + require.True(t, conn2.tryAcquire()) + conn1.waiters.Store(1) + conn2.waiters.Store(1) + + ap.conns[conn1.id] = conn1 + ap.conns[conn2.id] = conn2 + + target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) + require.Equal(t, 6, target, "应按 inflight+waiters 与 target_utilization 自适应扩容到上限") + + conn1.release() + conn2.release() + conn1.waiters.Store(0) + conn2.waiters.Store(0) + target = pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) + require.Equal(t, 1, target, "低负载时应缩回到最小空闲连接") +} + +func TestOpenAIWSConnPool_TargetConnCountMinIdleZero(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + + pool := newOpenAIWSConnPool(cfg) + ap := pool.getOrCreateAccountPool(66) + + target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) + require.Equal(t, 0, target, "min_idle=0 且无负载时应允许缩容到 0") +} + +func TestOpenAIWSConnPool_EnsureTargetIdleAsync(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(&openAIWSFakeDialer{}) + + accountID := int64(77) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return len(ap.conns) >= 2 + }, 2*time.Second, 20*time.Millisecond) + + metrics := pool.SnapshotMetrics() + require.GreaterOrEqual(t, metrics.ScaleUpTotal, int64(2)) +} + +func TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 500 + + pool := newOpenAIWSConnPool(cfg) + dialer := &openAIWSCountingDialer{} + pool.setClientDialerForTest(dialer) + + accountID := int64(178) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return len(ap.conns) >= 2 && !ap.prewarmActive + }, 2*time.Second, 20*time.Millisecond) + firstDialCount := dialer.DialCount() + require.GreaterOrEqual(t, firstDialCount, 2) + + // 人工制造缺口触发新一轮预热需求。 + ap, ok := pool.getAccountPool(accountID) + require.True(t, ok) + require.NotNil(t, ap) + ap.mu.Lock() + for id := range ap.conns { + delete(ap.conns, id) + break + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + time.Sleep(120 * time.Millisecond) + require.Equal(t, firstDialCount, dialer.DialCount(), "cooldown 窗口内不应再次触发预热") + + time.Sleep(450 * time.Millisecond) + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + return dialer.DialCount() > firstDialCount + }, 2*time.Second, 20*time.Millisecond) +} + +func TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 0 + + pool := newOpenAIWSConnPool(cfg) + dialer := &openAIWSAlwaysFailDialer{} + pool.setClientDialerForTest(dialer) + + accountID := int64(279) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return !ap.prewarmActive + }, 2*time.Second, 20*time.Millisecond) + + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return !ap.prewarmActive + }, 2*time.Second, 20*time.Millisecond) + require.Equal(t, 2, dialer.DialCount()) + + // 连续失败达到阈值后,新的预热触发应被抑制,不再继续拨号。 + pool.ensureTargetIdleAsync(accountID) + time.Sleep(120 * time.Millisecond) + require.Equal(t, 2, dialer.DialCount()) +} + +func TestOpenAIWSConnPool_AcquireQueueWaitMetrics(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4 + + pool := newOpenAIWSConnPool(cfg) + accountID := int64(99) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + conn := newOpenAIWSConn("busy", accountID, &openAIWSFakeConn{}, nil) + require.True(t, conn.tryAcquire()) // 占用连接,触发后续排队 + + ap := pool.ensureAccountPoolLocked(accountID) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + go func() { + time.Sleep(60 * time.Millisecond) + conn.release() + }() + + lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.NoError(t, err) + require.NotNil(t, lease) + require.True(t, lease.Reused()) + require.GreaterOrEqual(t, lease.QueueWaitDuration(), 50*time.Millisecond) + lease.Release() + + metrics := pool.SnapshotMetrics() + require.GreaterOrEqual(t, metrics.AcquireQueueWaitTotal, int64(1)) + require.Greater(t, metrics.AcquireQueueWaitMsTotal, int64(0)) + require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1)) +} + +func TestOpenAIWSConnPool_ForceNewConnSkipsReuse(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + + pool := newOpenAIWSConnPool(cfg) + dialer := &openAIWSCountingDialer{} + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + lease1, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Release() + + lease2, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + ForceNewConn: true, + }) + require.NoError(t, err) + require.NotNil(t, lease2) + lease2.Release() + + require.Equal(t, 2, dialer.DialCount(), "ForceNewConn=true 时应跳过空闲连接复用并新建连接") +} + +func TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 124, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(account.ID) + otherConn := newOpenAIWSConn("other_conn", account.ID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[otherConn.id] = otherConn + ap.mu.Unlock() + + _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + ForcePreferredConn: true, + }) + require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable) + + _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: "missing_conn", + ForcePreferredConn: true, + }) + require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable) +} + +func TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 125, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(account.ID) + preferredConn := newOpenAIWSConn("preferred_conn", account.ID, &openAIWSFakeConn{}, nil) + otherConn := newOpenAIWSConn("other_conn_idle", account.ID, &openAIWSFakeConn{}, nil) + require.True(t, preferredConn.tryAcquire(), "先占用 preferred 连接,触发排队获取") + ap.mu.Lock() + ap.conns[preferredConn.id] = preferredConn + ap.conns[otherConn.id] = otherConn + ap.lastCleanupAt = time.Now() + ap.mu.Unlock() + + go func() { + time.Sleep(60 * time.Millisecond) + preferredConn.release() + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + lease, err := pool.Acquire(ctx, openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: preferredConn.id, + ForcePreferredConn: true, + }) + require.NoError(t, err) + require.NotNil(t, lease) + require.Equal(t, preferredConn.id, lease.ConnID(), "严格模式应只等待并复用 preferred 连接,不可漂移") + require.GreaterOrEqual(t, lease.QueueWaitDuration(), 40*time.Millisecond) + lease.Release() + require.True(t, otherConn.tryAcquire(), "other 连接不应被严格模式抢占") + otherConn.release() +} + +func TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 127, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(account.ID) + preferredConn := newOpenAIWSConn("preferred_conn_direct", account.ID, &openAIWSFakeConn{}, nil) + otherConn := newOpenAIWSConn("other_conn_direct", account.ID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[preferredConn.id] = preferredConn + ap.conns[otherConn.id] = otherConn + ap.lastCleanupAt = time.Now() + ap.mu.Unlock() + + lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: preferredConn.id, + ForcePreferredConn: true, + }) + require.NoError(t, err) + require.Equal(t, preferredConn.id, lease.ConnID(), "preferred 空闲时应直接命中") + lease.Release() + + require.True(t, preferredConn.tryAcquire()) + preferredConn.waiters.Store(1) + _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: preferredConn.id, + ForcePreferredConn: true, + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "严格模式下队列满应直接失败,不得漂移") + preferredConn.waiters.Store(0) + preferredConn.release() +} + +func TestOpenAIWSConnPool_CleanupSkipsPinnedConn(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0 + + pool := newOpenAIWSConnPool(cfg) + accountID := int64(126) + ap := pool.getOrCreateAccountPool(accountID) + pinnedConn := newOpenAIWSConn("pinned_conn", accountID, &openAIWSFakeConn{}, nil) + idleConn := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[pinnedConn.id] = pinnedConn + ap.conns[idleConn.id] = idleConn + ap.mu.Unlock() + + require.True(t, pool.PinConn(accountID, pinnedConn.id)) + evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()) + closeOpenAIWSConns(evicted) + + ap.mu.Lock() + _, pinnedExists := ap.conns[pinnedConn.id] + _, idleExists := ap.conns[idleConn.id] + ap.mu.Unlock() + require.True(t, pinnedExists, "被 active ingress 绑定的连接不应被 cleanup 回收") + require.False(t, idleExists, "非绑定的空闲连接应被回收") + + pool.UnpinConn(accountID, pinnedConn.id) + evicted = pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()) + closeOpenAIWSConns(evicted) + ap.mu.Lock() + _, pinnedExists = ap.conns[pinnedConn.id] + ap.mu.Unlock() + require.False(t, pinnedExists, "解绑后连接应可被正常回收") +} + +func TestOpenAIWSConnPool_PinUnpinConnBranches(t *testing.T) { + var nilPool *openAIWSConnPool + require.False(t, nilPool.PinConn(1, "x")) + nilPool.UnpinConn(1, "x") + + cfg := &config.Config{} + pool := newOpenAIWSConnPool(cfg) + accountID := int64(128) + ap := &openAIWSAccountPool{ + conns: map[string]*openAIWSConn{}, + } + pool.accounts.Store(accountID, ap) + + require.False(t, pool.PinConn(0, "x")) + require.False(t, pool.PinConn(999, "x")) + require.False(t, pool.PinConn(accountID, "")) + require.False(t, pool.PinConn(accountID, "missing")) + + conn := newOpenAIWSConn("pin_refcount", accountID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + require.True(t, pool.PinConn(accountID, conn.id)) + require.True(t, pool.PinConn(accountID, conn.id)) + + ap.mu.Lock() + require.Equal(t, 2, ap.pinnedConns[conn.id]) + ap.mu.Unlock() + + pool.UnpinConn(accountID, conn.id) + ap.mu.Lock() + require.Equal(t, 1, ap.pinnedConns[conn.id]) + ap.mu.Unlock() + + pool.UnpinConn(accountID, conn.id) + ap.mu.Lock() + _, exists := ap.pinnedConns[conn.id] + ap.mu.Unlock() + require.False(t, exists) + + pool.UnpinConn(accountID, conn.id) + pool.UnpinConn(accountID, "") + pool.UnpinConn(0, conn.id) + pool.UnpinConn(999, conn.id) +} + +func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true + cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0 + cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6 + + pool := newOpenAIWSConnPool(cfg) + + oauthHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 10} + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(oauthHigh), "应受全局硬上限约束") + + oauthLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 3} + require.Equal(t, 3, pool.effectiveMaxConnsByAccount(oauthLow)) + + apiKeyHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 10} + require.Equal(t, 6, pool.effectiveMaxConnsByAccount(apiKeyHigh), "API Key 应按系数缩放") + + apiKeyLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1} + require.Equal(t, 1, pool.effectiveMaxConnsByAccount(apiKeyLow), "最小值应保持为 1") + + unlimited := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0} + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(unlimited), "无限并发应回退到全局硬上限") + + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(nil), "缺少账号上下文应回退到全局硬上限") +} + +func TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false + cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0 + cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 1.0 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 2} + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(account), "关闭动态模式后应保持旧行为") +} + +func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true + cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0.3 + cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6 + + pool := newOpenAIWSConnPool(cfg) + + high := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 20} + require.Equal(t, 20, pool.effectiveMaxConnsByAccount(high), "v2 路径应直接使用账号并发数作为池上限") + + nonPositive := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 0} + require.Equal(t, 0, pool.effectiveMaxConnsByAccount(nonPositive), "并发数<=0 时应不可调度") +} + +func TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + pool := newOpenAIWSConnPool(cfg) + + account := &Account{ID: 901, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0} + _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull) +} + +func TestOpenAIWSConnLease_ReadMessageWithContextTimeout_PerRead(t *testing.T) { + conn := newOpenAIWSConn("timeout", 1, &openAIWSBlockingConn{readDelay: 80 * time.Millisecond}, nil) + lease := &openAIWSConnLease{conn: conn} + + _, err := lease.ReadMessageWithContextTimeout(context.Background(), 20*time.Millisecond) + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + + payload, err := lease.ReadMessageWithContextTimeout(context.Background(), 150*time.Millisecond) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + parentCtx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = lease.ReadMessageWithContextTimeout(parentCtx, 150*time.Millisecond) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) +} + +func TestOpenAIWSConnLease_WriteJSONWithContextTimeout_RespectsParentContext(t *testing.T) { + conn := newOpenAIWSConn("write_timeout_ctx", 1, &openAIWSWriteBlockingConn{}, nil) + lease := &openAIWSConnLease{conn: conn} + + parentCtx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + start := time.Now() + err := lease.WriteJSONWithContextTimeout(parentCtx, map[string]any{"type": "response.create"}, 2*time.Minute) + elapsed := time.Since(start) + + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + require.Less(t, elapsed, 200*time.Millisecond) +} + +func TestOpenAIWSConnLease_PingWithTimeout(t *testing.T) { + conn := newOpenAIWSConn("ping_ok", 1, &openAIWSFakeConn{}, nil) + lease := &openAIWSConnLease{conn: conn} + require.NoError(t, lease.PingWithTimeout(50*time.Millisecond)) + + var nilLease *openAIWSConnLease + err := nilLease.PingWithTimeout(50 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently(t *testing.T) { + conn := newOpenAIWSConn("full_duplex", 1, &openAIWSBlockingConn{readDelay: 120 * time.Millisecond}, nil) + + readDone := make(chan error, 1) + go func() { + _, err := conn.readMessageWithContextTimeout(context.Background(), 200*time.Millisecond) + readDone <- err + }() + + // 让读取先占用 readMu。 + time.Sleep(20 * time.Millisecond) + + start := time.Now() + err := conn.pingWithTimeout(50 * time.Millisecond) + elapsed := time.Since(start) + + require.NoError(t, err) + require.Less(t, elapsed, 80*time.Millisecond, "写路径不应被读锁长期阻塞") + require.NoError(t, <-readDone) +} + +func TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + pool := newOpenAIWSConnPool(cfg) + + accountID := int64(301) + ap := pool.getOrCreateAccountPool(accountID) + conn := newOpenAIWSConn("dead_idle", accountID, &openAIWSPingFailConn{}, nil) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + + pool.runBackgroundPingSweep() + + ap.mu.Lock() + _, exists := ap.conns[conn.id] + ap.mu.Unlock() + require.False(t, exists, "后台 ping 失败的空闲连接应被回收") +} + +func TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + pool := newOpenAIWSConnPool(cfg) + + accountID := int64(302) + ap := pool.getOrCreateAccountPool(accountID) + stale := newOpenAIWSConn("stale_bg", accountID, &openAIWSFakeConn{}, nil) + stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + ap.mu.Lock() + ap.conns[stale.id] = stale + ap.mu.Unlock() + + pool.runBackgroundCleanupSweep(time.Now()) + + ap.mu.Lock() + _, exists := ap.conns[stale.id] + ap.mu.Unlock() + require.False(t, exists, "后台清理应在无新 acquire 时也回收过期连接") +} + +func TestOpenAIWSConnPool_BackgroundWorkerGuardBranches(t *testing.T) { + var nilPool *openAIWSConnPool + require.NotPanics(t, func() { + nilPool.startBackgroundWorkers() + nilPool.runBackgroundPingWorker() + nilPool.runBackgroundPingSweep() + _ = nilPool.snapshotIdleConnsForPing() + nilPool.runBackgroundCleanupWorker() + nilPool.runBackgroundCleanupSweep(time.Now()) + }) + + poolNoStop := &openAIWSConnPool{} + require.NotPanics(t, func() { + poolNoStop.startBackgroundWorkers() + }) + + poolStopPing := &openAIWSConnPool{workerStopCh: make(chan struct{})} + pingDone := make(chan struct{}) + go func() { + poolStopPing.runBackgroundPingWorker() + close(pingDone) + }() + close(poolStopPing.workerStopCh) + select { + case <-pingDone: + case <-time.After(500 * time.Millisecond): + t.Fatal("runBackgroundPingWorker 未在 stop 信号后退出") + } + + poolStopCleanup := &openAIWSConnPool{workerStopCh: make(chan struct{})} + cleanupDone := make(chan struct{}) + go func() { + poolStopCleanup.runBackgroundCleanupWorker() + close(cleanupDone) + }() + close(poolStopCleanup.workerStopCh) + select { + case <-cleanupDone: + case <-time.After(500 * time.Millisecond): + t.Fatal("runBackgroundCleanupWorker 未在 stop 信号后退出") + } +} + +func TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries(t *testing.T) { + pool := &openAIWSConnPool{} + pool.accounts.Store("invalid-key", &openAIWSAccountPool{}) + pool.accounts.Store(int64(123), "invalid-value") + + accountID := int64(123) + ap := &openAIWSAccountPool{ + conns: make(map[string]*openAIWSConn), + } + ap.conns["nil_conn"] = nil + + leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil) + require.True(t, leased.tryAcquire()) + ap.conns[leased.id] = leased + + waiting := newOpenAIWSConn("waiting", accountID, &openAIWSFakeConn{}, nil) + waiting.waiters.Store(1) + ap.conns[waiting.id] = waiting + + idle := newOpenAIWSConn("idle", accountID, &openAIWSFakeConn{}, nil) + ap.conns[idle.id] = idle + + pool.accounts.Store(accountID, ap) + candidates := pool.snapshotIdleConnsForPing() + require.Len(t, candidates, 1) + require.Equal(t, idle.id, candidates[0].conn.id) +} + +func TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true + + pool := &openAIWSConnPool{cfg: cfg} + pool.accounts.Store("bad-key", "bad-value") + + accountID := int64(2026) + ap := &openAIWSAccountPool{ + conns: make(map[string]*openAIWSConn), + } + ap.conns["nil_conn"] = nil + stale := newOpenAIWSConn("stale_bg_cleanup", accountID, &openAIWSFakeConn{}, nil) + stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + ap.conns[stale.id] = stale + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: &Account{ + ID: accountID, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + }, + } + pool.accounts.Store(accountID, ap) + + now := time.Now() + require.NotPanics(t, func() { + pool.runBackgroundCleanupSweep(now) + }) + + ap.mu.Lock() + _, nilConnExists := ap.conns["nil_conn"] + _, exists := ap.conns[stale.id] + lastCleanupAt := ap.lastCleanupAt + ap.mu.Unlock() + + require.False(t, nilConnExists, "后台清理应移除无效 nil 连接条目") + require.False(t, exists, "后台清理应清理过期连接") + require.Equal(t, now, lastCleanupAt) +} + +func TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured(t *testing.T) { + var nilPool *openAIWSConnPool + require.Equal(t, 256, nilPool.queueLimitPerConn()) + + pool := &openAIWSConnPool{cfg: &config.Config{}} + require.Equal(t, 256, pool.queueLimitPerConn()) + + pool.cfg.Gateway.OpenAIWS.QueueLimitPerConn = 9 + require.Equal(t, 9, pool.queueLimitPerConn()) +} + +func TestOpenAIWSConnPool_Close(t *testing.T) { + cfg := &config.Config{} + pool := newOpenAIWSConnPool(cfg) + + // Close 应该可以安全调用 + pool.Close() + + // workerStopCh 应已关闭 + select { + case <-pool.workerStopCh: + // 预期:channel 已关闭 + default: + t.Fatal("Close 后 workerStopCh 应已关闭") + } + + // 多次调用 Close 不应 panic + pool.Close() + + // nil pool 调用 Close 不应 panic + var nilPool *openAIWSConnPool + nilPool.Close() +} + +func TestOpenAIWSDialError_ErrorAndUnwrap(t *testing.T) { + baseErr := errors.New("boom") + dialErr := &openAIWSDialError{StatusCode: 502, Err: baseErr} + require.Contains(t, dialErr.Error(), "status=502") + require.ErrorIs(t, dialErr.Unwrap(), baseErr) + + noStatus := &openAIWSDialError{Err: baseErr} + require.Contains(t, noStatus.Error(), "boom") + + var nilDialErr *openAIWSDialError + require.Equal(t, "", nilDialErr.Error()) + require.NoError(t, nilDialErr.Unwrap()) +} + +func TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats(t *testing.T) { + conn := newOpenAIWSConn("helper_conn", 1, &openAIWSFakeConn{}, http.Header{ + "X-Test": []string{" value "}, + }) + lease := &openAIWSConnLease{conn: conn} + + require.NoError(t, lease.WriteJSONContext(context.Background(), map[string]any{"type": "response.create"})) + payload, err := lease.ReadMessage(100 * time.Millisecond) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + payload, err = lease.ReadMessageContext(context.Background()) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + payload, err = conn.readMessageWithTimeout(100 * time.Millisecond) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + require.Equal(t, "value", conn.handshakeHeader(" X-Test ")) + require.NotZero(t, conn.createdAt()) + require.NotZero(t, conn.lastUsedAt()) + require.GreaterOrEqual(t, conn.age(time.Now()), time.Duration(0)) + require.GreaterOrEqual(t, conn.idleDuration(time.Now()), time.Duration(0)) + require.False(t, conn.isLeased()) + + // 覆盖空上下文路径 + _, err = conn.readMessage(context.Background()) + require.NoError(t, err) + + // 覆盖 nil 保护分支 + var nilConn *openAIWSConn + require.ErrorIs(t, nilConn.writeJSONWithTimeout(context.Background(), map[string]any{}, time.Second), errOpenAIWSConnClosed) + _, err = nilConn.readMessageWithTimeout(10 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = nilConn.readMessageWithContextTimeout(context.Background(), 10*time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad(t *testing.T) { + pool := &openAIWSConnPool{} + accountID := int64(404) + ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}} + + idleOld := newOpenAIWSConn("idle_old", accountID, &openAIWSFakeConn{}, nil) + idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano()) + idleNew := newOpenAIWSConn("idle_new", accountID, &openAIWSFakeConn{}, nil) + idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano()) + leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil) + require.True(t, leased.tryAcquire()) + leased.waiters.Store(2) + + ap.conns[idleOld.id] = idleOld + ap.conns[idleNew.id] = idleNew + ap.conns[leased.id] = leased + + oldest := pool.pickOldestIdleConnLocked(ap) + require.NotNil(t, oldest) + require.Equal(t, idleOld.id, oldest.id) + + inflight, waiters := accountPoolLoadLocked(ap) + require.Equal(t, 1, inflight) + require.Equal(t, 2, waiters) + + pool.accounts.Store(accountID, ap) + loadInflight, loadWaiters, conns := pool.AccountPoolLoad(accountID) + require.Equal(t, 1, loadInflight) + require.Equal(t, 2, loadWaiters) + require.Equal(t, 3, conns) + + zeroInflight, zeroWaiters, zeroConns := pool.AccountPoolLoad(0) + require.Equal(t, 0, zeroInflight) + require.Equal(t, 0, zeroWaiters) + require.Equal(t, 0, zeroConns) +} + +func TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel(t *testing.T) { + pool := &openAIWSConnPool{} + release := make(chan struct{}) + pool.workerWg.Add(1) + go func() { + defer pool.workerWg.Done() + <-release + }() + + closed := make(chan struct{}) + go func() { + pool.Close() + close(closed) + }() + + select { + case <-closed: + t.Fatal("Close 不应在 WaitGroup 未完成时提前返回") + case <-time.After(30 * time.Millisecond): + } + + close(release) + select { + case <-closed: + case <-time.After(time.Second): + t.Fatal("Close 未等待 workerWg 完成") + } +} + +func TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections(t *testing.T) { + pool := &openAIWSConnPool{ + workerStopCh: make(chan struct{}), + } + + accountID := int64(606) + ap := &openAIWSAccountPool{ + conns: map[string]*openAIWSConn{}, + } + idle := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil) + leased := newOpenAIWSConn("leased_conn", accountID, &openAIWSFakeConn{}, nil) + require.True(t, leased.tryAcquire()) + + ap.conns[idle.id] = idle + ap.conns[leased.id] = leased + pool.accounts.Store(accountID, ap) + pool.accounts.Store("invalid-key", "invalid-value") + + pool.Close() + + select { + case <-idle.closedCh: + // idle should be closed + default: + t.Fatal("空闲连接应在 Close 时被关闭") + } + + select { + case <-leased.closedCh: + t.Fatal("已租赁连接不应在 Close 时被关闭") + default: + } + + leased.release() + pool.Close() +} + +func TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit(t *testing.T) { + cfg := &config.Config{} + pool := newOpenAIWSConnPool(cfg) + accountID := int64(505) + ap := pool.getOrCreateAccountPool(accountID) + + var current atomic.Int32 + var maxConcurrent atomic.Int32 + release := make(chan struct{}) + for i := 0; i < 25; i++ { + conn := newOpenAIWSConn(pool.nextConnID(accountID), accountID, &openAIWSPingBlockingConn{ + current: ¤t, + maxConcurrent: &maxConcurrent, + release: release, + }, nil) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + } + + done := make(chan struct{}) + go func() { + pool.runBackgroundPingSweep() + close(done) + }() + + require.Eventually(t, func() bool { + return maxConcurrent.Load() >= 10 + }, time.Second, 10*time.Millisecond) + + close(release) + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("runBackgroundPingSweep 未在释放后完成") + } + + require.LessOrEqual(t, maxConcurrent.Load(), int32(10)) +} + +func TestOpenAIWSConnLease_BasicGetterBranches(t *testing.T) { + var nilLease *openAIWSConnLease + require.Equal(t, "", nilLease.ConnID()) + require.Equal(t, time.Duration(0), nilLease.QueueWaitDuration()) + require.Equal(t, time.Duration(0), nilLease.ConnPickDuration()) + require.False(t, nilLease.Reused()) + require.Equal(t, "", nilLease.HandshakeHeader("x-test")) + require.False(t, nilLease.IsPrewarmed()) + nilLease.MarkPrewarmed() + nilLease.Release() + + conn := newOpenAIWSConn("getter_conn", 1, &openAIWSFakeConn{}, http.Header{"X-Test": []string{"ok"}}) + lease := &openAIWSConnLease{ + conn: conn, + queueWait: 3 * time.Millisecond, + connPick: 4 * time.Millisecond, + reused: true, + } + require.Equal(t, "getter_conn", lease.ConnID()) + require.Equal(t, 3*time.Millisecond, lease.QueueWaitDuration()) + require.Equal(t, 4*time.Millisecond, lease.ConnPickDuration()) + require.True(t, lease.Reused()) + require.Equal(t, "ok", lease.HandshakeHeader("x-test")) + require.False(t, lease.IsPrewarmed()) + lease.MarkPrewarmed() + require.True(t, lease.IsPrewarmed()) + lease.Release() +} + +func TestOpenAIWSConnPool_UtilityBranches(t *testing.T) { + var nilPool *openAIWSConnPool + require.Equal(t, OpenAIWSPoolMetricsSnapshot{}, nilPool.SnapshotMetrics()) + require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, nilPool.SnapshotTransportMetrics()) + + pool := &openAIWSConnPool{cfg: &config.Config{}} + pool.metrics.acquireTotal.Store(7) + pool.metrics.acquireReuseTotal.Store(3) + metrics := pool.SnapshotMetrics() + require.Equal(t, int64(7), metrics.AcquireTotal) + require.Equal(t, int64(3), metrics.AcquireReuseTotal) + + // 非 transport metrics dialer 路径 + pool.clientDialer = &openAIWSFakeDialer{} + require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, pool.SnapshotTransportMetrics()) + pool.setClientDialerForTest(nil) + require.NotNil(t, pool.clientDialer) + + require.Equal(t, 8, nilPool.maxConnsHardCap()) + require.False(t, nilPool.dynamicMaxConnsEnabled()) + require.Equal(t, 1.0, nilPool.maxConnsFactorByAccount(nil)) + require.Equal(t, 0, nilPool.minIdlePerAccount()) + require.Equal(t, 4, nilPool.maxIdlePerAccount()) + require.Equal(t, 256, nilPool.queueLimitPerConn()) + require.Equal(t, 0.7, nilPool.targetUtilization()) + require.Equal(t, time.Duration(0), nilPool.prewarmCooldown()) + require.Equal(t, 10*time.Second, nilPool.dialTimeout()) + + // shouldSuppressPrewarmLocked 覆盖 3 条分支 + now := time.Now() + apNilFail := &openAIWSAccountPool{prewarmFails: 1} + require.False(t, pool.shouldSuppressPrewarmLocked(apNilFail, now)) + apZeroTime := &openAIWSAccountPool{prewarmFails: 2} + require.False(t, pool.shouldSuppressPrewarmLocked(apZeroTime, now)) + require.Equal(t, 0, apZeroTime.prewarmFails) + apOldFail := &openAIWSAccountPool{prewarmFails: 2, prewarmFailAt: now.Add(-openAIWSPrewarmFailureWindow - time.Second)} + require.False(t, pool.shouldSuppressPrewarmLocked(apOldFail, now)) + apRecentFail := &openAIWSAccountPool{prewarmFails: openAIWSPrewarmFailureSuppress, prewarmFailAt: now} + require.True(t, pool.shouldSuppressPrewarmLocked(apRecentFail, now)) + + // recordConnPickDuration 的保护分支 + nilPool.recordConnPickDuration(10 * time.Millisecond) + pool.recordConnPickDuration(-10 * time.Millisecond) + require.Equal(t, int64(1), pool.metrics.connPickTotal.Load()) + + // account pool 读写分支 + require.Nil(t, nilPool.getOrCreateAccountPool(1)) + require.Nil(t, pool.getOrCreateAccountPool(0)) + pool.accounts.Store(int64(7), "invalid") + ap := pool.getOrCreateAccountPool(7) + require.NotNil(t, ap) + _, ok := pool.getAccountPool(0) + require.False(t, ok) + _, ok = pool.getAccountPool(12345) + require.False(t, ok) + pool.accounts.Store(int64(8), "bad-type") + _, ok = pool.getAccountPool(8) + require.False(t, ok) + + // health check 条件 + require.False(t, pool.shouldHealthCheckConn(nil)) + conn := newOpenAIWSConn("health", 1, &openAIWSFakeConn{}, nil) + conn.lastUsedNano.Store(time.Now().Add(-openAIWSConnHealthCheckIdle - time.Second).UnixNano()) + require.True(t, pool.shouldHealthCheckConn(conn)) +} + +func TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches(t *testing.T) { + var nilConn *openAIWSConn + nilConn.touch() + require.Equal(t, time.Time{}, nilConn.createdAt()) + require.Equal(t, time.Time{}, nilConn.lastUsedAt()) + require.Equal(t, time.Duration(0), nilConn.idleDuration(time.Now())) + require.Equal(t, time.Duration(0), nilConn.age(time.Now())) + require.False(t, nilConn.isLeased()) + require.False(t, nilConn.isPrewarmed()) + nilConn.markPrewarmed() + + conn := newOpenAIWSConn("lease_state", 1, &openAIWSFakeConn{}, nil) + require.True(t, conn.tryAcquire()) + require.True(t, conn.isLeased()) + conn.release() + require.False(t, conn.isLeased()) + conn.close() + require.False(t, conn.tryAcquire()) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := conn.acquire(ctx) + require.Error(t, err) +} + +func TestOpenAIWSConnLease_ReadWriteNilConnBranches(t *testing.T) { + lease := &openAIWSConnLease{} + require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) + require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed) + _, err := lease.ReadMessage(10 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageContext(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestOpenAIWSConnLease_ReleasedLeaseGuards(t *testing.T) { + conn := newOpenAIWSConn("released_guard", 1, &openAIWSFakeConn{}, nil) + lease := &openAIWSConnLease{conn: conn} + + require.NoError(t, lease.PingWithTimeout(50*time.Millisecond)) + + lease.Release() + lease.Release() // idempotent + + require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) + require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed) + require.ErrorIs(t, lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) + + _, err := lease.ReadMessage(10 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageContext(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + + require.ErrorIs(t, lease.PingWithTimeout(50*time.Millisecond), errOpenAIWSConnClosed) +} + +func TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction(t *testing.T) { + conn := newOpenAIWSConn("released_markbroken", 7, &openAIWSFakeConn{}, nil) + ap := &openAIWSAccountPool{ + conns: map[string]*openAIWSConn{ + conn.id: conn, + }, + } + pool := &openAIWSConnPool{} + pool.accounts.Store(int64(7), ap) + + lease := &openAIWSConnLease{ + pool: pool, + accountID: 7, + conn: conn, + } + + lease.Release() + lease.MarkBroken() + + ap.mu.Lock() + _, exists := ap.conns[conn.id] + ap.mu.Unlock() + require.True(t, exists, "released lease should not evict active pool connection") +} + +func TestOpenAIWSConn_AdditionalGuardBranches(t *testing.T) { + var nilConn *openAIWSConn + require.False(t, nilConn.tryAcquire()) + require.ErrorIs(t, nilConn.acquire(context.Background()), errOpenAIWSConnClosed) + nilConn.release() + nilConn.close() + require.Equal(t, "", nilConn.handshakeHeader("x-test")) + + connBusy := newOpenAIWSConn("busy_ctx", 1, &openAIWSFakeConn{}, nil) + require.True(t, connBusy.tryAcquire()) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + require.ErrorIs(t, connBusy.acquire(ctx), context.Canceled) + connBusy.release() + + connClosed := newOpenAIWSConn("closed_guard", 1, &openAIWSFakeConn{}, nil) + connClosed.close() + require.ErrorIs( + t, + connClosed.writeJSONWithTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), + errOpenAIWSConnClosed, + ) + _, err := connClosed.readMessageWithContextTimeout(context.Background(), time.Second) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + require.ErrorIs(t, connClosed.pingWithTimeout(time.Second), errOpenAIWSConnClosed) + + connNoWS := newOpenAIWSConn("no_ws", 1, nil, nil) + require.ErrorIs(t, connNoWS.writeJSON(map[string]any{"k": "v"}, context.Background()), errOpenAIWSConnClosed) + _, err = connNoWS.readMessage(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + require.ErrorIs(t, connNoWS.pingWithTimeout(time.Second), errOpenAIWSConnClosed) + require.Equal(t, "", connNoWS.handshakeHeader("x-test")) + + connOK := newOpenAIWSConn("ok", 1, &openAIWSFakeConn{}, nil) + require.NoError(t, connOK.writeJSON(map[string]any{"k": "v"}, nil)) + _, err = connOK.readMessageWithContextTimeout(context.Background(), 0) + require.NoError(t, err) + require.NoError(t, connOK.pingWithTimeout(0)) + + connZero := newOpenAIWSConn("zero_ts", 1, &openAIWSFakeConn{}, nil) + connZero.createdAtNano.Store(0) + connZero.lastUsedNano.Store(0) + require.True(t, connZero.createdAt().IsZero()) + require.True(t, connZero.lastUsedAt().IsZero()) + require.Equal(t, time.Duration(0), connZero.idleDuration(time.Now())) + require.Equal(t, time.Duration(0), connZero.age(time.Now())) + + require.Nil(t, cloneOpenAIWSAcquireRequestPtr(nil)) + copied := cloneHeader(http.Header{ + "X-Empty": []string{}, + "X-Test": []string{"v1"}, + }) + require.Contains(t, copied, "X-Empty") + require.Nil(t, copied["X-Empty"]) + require.Equal(t, "v1", copied.Get("X-Test")) + + closeOpenAIWSConns([]*openAIWSConn{nil, connOK}) +} + +func TestOpenAIWSConnLease_MarkBrokenEvictsConn(t *testing.T) { + pool := newOpenAIWSConnPool(&config.Config{}) + accountID := int64(5001) + conn := newOpenAIWSConn("broken_me", accountID, &openAIWSFakeConn{}, nil) + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + + lease := &openAIWSConnLease{ + pool: pool, + accountID: accountID, + conn: conn, + } + lease.MarkBroken() + + ap.mu.Lock() + _, exists := ap.conns[conn.id] + ap.mu.Unlock() + require.False(t, exists) + require.False(t, conn.tryAcquire(), "被标记为 broken 的连接应被关闭") +} + +func TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + pool := newOpenAIWSConnPool(cfg) + + require.Equal(t, 0, pool.targetConnCountLocked(nil, 1)) + ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}} + require.Equal(t, 0, pool.targetConnCountLocked(ap, 0)) + + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 3 + require.Equal(t, 1, pool.targetConnCountLocked(ap, 1), "minIdle 应被 maxConns 截断") + + // 覆盖 waiters>0 且 target 需要至少 len(conns)+1 的分支 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.9 + busy := newOpenAIWSConn("busy_target", 2, &openAIWSFakeConn{}, nil) + require.True(t, busy.tryAcquire()) + busy.waiters.Store(1) + ap.conns[busy.id] = busy + target := pool.targetConnCountLocked(ap, 4) + require.GreaterOrEqual(t, target, len(ap.conns)+1) + + // prewarm: account pool 缺失时,拨号后的连接应被关闭并提前返回 + req := openAIWSAcquireRequest{ + Account: &Account{ID: 999, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}, + WSURL: "wss://example.com/v1/responses", + } + pool.prewarmConns(999, req, 1) + + // prewarm: 拨号失败分支(prewarmFails 累加) + accountID := int64(1000) + failPool := newOpenAIWSConnPool(cfg) + failPool.setClientDialerForTest(&openAIWSAlwaysFailDialer{}) + apFail := failPool.getOrCreateAccountPool(accountID) + apFail.mu.Lock() + apFail.creating = 1 + apFail.mu.Unlock() + req.Account.ID = accountID + failPool.prewarmConns(accountID, req, 1) + apFail.mu.Lock() + require.GreaterOrEqual(t, apFail.prewarmFails, 1) + apFail.mu.Unlock() +} + +func TestOpenAIWSConnPool_Acquire_ErrorBranches(t *testing.T) { + var nilPool *openAIWSConnPool + _, err := nilPool.Acquire(context.Background(), openAIWSAcquireRequest{}) + require.Error(t, err) + + pool := newOpenAIWSConnPool(&config.Config{}) + _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: &Account{ID: 1}, + WSURL: " ", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "ws url is empty") + + // target=nil 分支:池满且仅有 nil 连接 + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1 + fullPool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 2001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := fullPool.getOrCreateAccountPool(account.ID) + ap.mu.Lock() + ap.conns["nil"] = nil + ap.lastCleanupAt = time.Now() + ap.mu.Unlock() + _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + + // queue full 分支:waiters 达上限 + account2 := &Account{ID: 2002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap2 := fullPool.getOrCreateAccountPool(account2.ID) + conn := newOpenAIWSConn("queue_full", account2.ID, &openAIWSFakeConn{}, nil) + require.True(t, conn.tryAcquire()) + conn.waiters.Store(1) + ap2.mu.Lock() + ap2.conns[conn.id] = conn + ap2.lastCleanupAt = time.Now() + ap2.mu.Unlock() + _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account2, + WSURL: "wss://example.com/v1/responses", + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull) +} + +type openAIWSFakeDialer struct{} + +func (d *openAIWSFakeDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + return &openAIWSFakeConn{}, 0, nil, nil +} + +type openAIWSCountingDialer struct { + mu sync.Mutex + dialCount int +} + +type openAIWSAlwaysFailDialer struct { + mu sync.Mutex + dialCount int +} + +type openAIWSPingBlockingConn struct { + current *atomic.Int32 + maxConcurrent *atomic.Int32 + release <-chan struct{} +} + +func (c *openAIWSPingBlockingConn) WriteJSON(context.Context, any) error { + return nil +} + +func (c *openAIWSPingBlockingConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_blocking_ping"}}`), nil +} + +func (c *openAIWSPingBlockingConn) Ping(ctx context.Context) error { + if c.current == nil || c.maxConcurrent == nil { + return nil + } + + now := c.current.Add(1) + for { + prev := c.maxConcurrent.Load() + if now <= prev || c.maxConcurrent.CompareAndSwap(prev, now) { + break + } + } + defer c.current.Add(-1) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.release: + return nil + } +} + +func (c *openAIWSPingBlockingConn) Close() error { + return nil +} + +func (d *openAIWSCountingDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + d.dialCount++ + d.mu.Unlock() + return &openAIWSFakeConn{}, 0, nil, nil +} + +func (d *openAIWSCountingDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +func (d *openAIWSAlwaysFailDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + d.dialCount++ + d.mu.Unlock() + return nil, 503, nil, errors.New("dial failed") +} + +func (d *openAIWSAlwaysFailDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSFakeConn struct { + mu sync.Mutex + closed bool + payload [][]byte +} + +func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return errors.New("closed") + } + c.payload = append(c.payload, []byte("ok")) + _ = value + return nil +} + +func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil, errors.New("closed") + } + return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil +} + +func (c *openAIWSFakeConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSFakeConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +type openAIWSBlockingConn struct { + readDelay time.Duration +} + +func (c *openAIWSBlockingConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + _ = value + return nil +} + +func (c *openAIWSBlockingConn) ReadMessage(ctx context.Context) ([]byte, error) { + delay := c.readDelay + if delay <= 0 { + delay = 10 * time.Millisecond + } + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + return []byte(`{"type":"response.completed","response":{"id":"resp_blocking"}}`), nil + } +} + +func (c *openAIWSBlockingConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSBlockingConn) Close() error { + return nil +} + +type openAIWSWriteBlockingConn struct{} + +func (c *openAIWSWriteBlockingConn) WriteJSON(ctx context.Context, _ any) error { + <-ctx.Done() + return ctx.Err() +} + +func (c *openAIWSWriteBlockingConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_write_block"}}`), nil +} + +func (c *openAIWSWriteBlockingConn) Ping(context.Context) error { + return nil +} + +func (c *openAIWSWriteBlockingConn) Close() error { + return nil +} + +type openAIWSPingFailConn struct{} + +func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error { + return nil +} + +func (c *openAIWSPingFailConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil +} + +func (c *openAIWSPingFailConn) Ping(context.Context) error { + return errors.New("ping failed") +} + +func (c *openAIWSPingFailConn) Close() error { + return nil +} + +type openAIWSContextProbeConn struct { + lastWriteCtx context.Context +} + +func (c *openAIWSContextProbeConn) WriteJSON(ctx context.Context, _ any) error { + c.lastWriteCtx = ctx + return nil +} + +func (c *openAIWSContextProbeConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_ctx_probe"}}`), nil +} + +func (c *openAIWSContextProbeConn) Ping(context.Context) error { + return nil +} + +func (c *openAIWSContextProbeConn) Close() error { + return nil +} + +type openAIWSNilConnDialer struct{} + +func (d *openAIWSNilConnDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + return nil, 200, nil, nil +} + +func TestOpenAIWSConnPool_DialConnNilConnection(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(&openAIWSNilConnDialer{}) + account := &Account{ID: 91, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "nil connection") +} + +func TestOpenAIWSConnPool_SnapshotTransportMetrics(t *testing.T) { + cfg := &config.Config{} + pool := newOpenAIWSConnPool(cfg) + + dialer, ok := pool.clientDialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + _, err := dialer.proxyHTTPClient("http://127.0.0.1:28080") + require.NoError(t, err) + _, err = dialer.proxyHTTPClient("http://127.0.0.1:28080") + require.NoError(t, err) + _, err = dialer.proxyHTTPClient("http://127.0.0.1:28081") + require.NoError(t, err) + + snapshot := pool.SnapshotTransportMetrics() + require.Equal(t, int64(1), snapshot.ProxyClientCacheHits) + require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses) + require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001) +} diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go new file mode 100644 index 00000000..df4d4871 --- /dev/null +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -0,0 +1,1218 @@ +package service + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下失败时不应回退 HTTP") +} + +func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 101, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_keep","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发") + require.NotNil(t, upstream.lastReq, "HTTP 入站应命中 HTTP 上游") + require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists(), "HTTP 路径应沿用原逻辑移除 previous_response_id") + + decision, _ := c.Get("openai_ws_transport_decision") + reason, _ := c.Get("openai_ws_transport_reason") + require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision) + require.Equal(t, "client_protocol_http", reason) +} + +func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = false + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUpgradeRequired) + _, _ = w.Write([]byte(`upgrade required`)) + })) + defer ws426Server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":8,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 12, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": ws426Server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_426","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "upgrade_required") + require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP") + require.Equal(t, http.StatusUpgradeRequired, rec.Code) + require.Contains(t, rec.Body.String(), "426") +} + +func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) { + gin.SetMode(gin.TestMode) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":2,"output_tokens":3,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 30 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 21, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + svc.markOpenAIWSFallbackCooling(account.ID, "upgrade_required") + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_cooling","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP") + + _, ok := c.Get("openai_ws_fallback_cooling") + require.False(t, ok, "已移除 fallback cooling 快捷回退路径") +} + +func TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsockets = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 31, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1/responses", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_v1","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "ws v1") + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "WSv1") + require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求") +} + +func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { + cfg := &config.Config{} + svc := NewOpenAIGatewayService( + nil, + nil, + nil, + nil, + nil, + cfg, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_missing", decision.Reason) +} + +func TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenReturnsWSError(t *testing.T) { + gin.SetMode(gin.TestMode) + ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUpgradeRequired) + _, _ = w.Write([]byte(`upgrade required`)) + })) + defer ws426Server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + c.String(http.StatusAccepted, "already-written") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 41, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": ws426Server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "ws fallback") + require.Nil(t, upstream.lastReq, "已写下游响应时,不应再回退 HTTP") +} + +func TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + + // 仅发送 response.created(非 token 事件)后立即关闭, + // 模拟线上“上游早期内部错误断连”的场景。 + if err := conn.WriteJSON(map[string]any{ + "type": "response.created", + "response": map[string]any{ + "id": "resp_ws_created_only", + "model": "gpt-5.3-codex", + }, + }); err != nil { + t.Errorf("write response.created failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 88, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 早期断连后不应再回退 HTTP") + require.Empty(t, rec.Body.String(), "未产出 token 前上游断连时不应写入下游半截流") +} + +func TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_retry_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 89, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 重连耗尽后不应再回退 HTTP") + require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load()) +} + +func TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_policy_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 1 + cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 2 + cfg.Gateway.OpenAIWS.RetryJitterRatio = 0 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 8901, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "策略违规关闭后不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "策略违规不应进行 WS 重试") +} + +func TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "websocket_connection_limit_reached", + "type": "server_error", + "message": "websocket connection limit reached", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_retry_limit","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 90, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "触发 websocket_connection_limit_reached 后不应回退 HTTP") + require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load()) +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDroppingPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt := wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + if attempt == 1 { + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_ws_prev_recover_ok", + "model": "gpt-5.3-codex", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + "input_tokens_details": map[string]any{ + "cached_tokens": 0, + }, + }, + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 91, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_ws_prev_recover_ok", result.RequestID) + require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(2), wsAttempts.Load(), "previous_response_not_found 应触发一次去掉 previous_response_id 的恢复重试") + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "resp_ws_prev_recover_ok", gjson.Get(rec.Body.String(), "id").String()) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) + wsRequestMu.Unlock() + require.Len(t, requests, 2) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id") + require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryForFunctionCallOutput(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 92, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "function_call_output 场景应跳过 previous_response_not_found 自动恢复") + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, strings.ToLower(rec.Body.String()), "previous response not found") + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) + wsRequestMu.Unlock() + require.Len(t, requests, 1) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryWithoutPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 93, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "缺少 previous_response_id 时应跳过自动恢复重试") + require.Equal(t, http.StatusBadRequest, rec.Code) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) + wsRequestMu.Unlock() + require.Len(t, requests, 1) + require.False(t, gjson.GetBytes(requests[0], "previous_response_id").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOnce(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 94, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(2), wsAttempts.Load(), "应只允许一次自动恢复重试") + require.Equal(t, http.StatusBadRequest, rec.Code) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) + wsRequestMu.Unlock() + require.Len(t, requests, 2) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id") + require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") +} diff --git a/backend/internal/service/openai_ws_protocol_resolver.go b/backend/internal/service/openai_ws_protocol_resolver.go new file mode 100644 index 00000000..368643be --- /dev/null +++ b/backend/internal/service/openai_ws_protocol_resolver.go @@ -0,0 +1,117 @@ +package service + +import "github.com/Wei-Shaw/sub2api/internal/config" + +// OpenAIUpstreamTransport 表示 OpenAI 上游传输协议。 +type OpenAIUpstreamTransport string + +const ( + OpenAIUpstreamTransportAny OpenAIUpstreamTransport = "" + OpenAIUpstreamTransportHTTPSSE OpenAIUpstreamTransport = "http_sse" + OpenAIUpstreamTransportResponsesWebsocket OpenAIUpstreamTransport = "responses_websockets" + OpenAIUpstreamTransportResponsesWebsocketV2 OpenAIUpstreamTransport = "responses_websockets_v2" +) + +// OpenAIWSProtocolDecision 表示协议决策结果。 +type OpenAIWSProtocolDecision struct { + Transport OpenAIUpstreamTransport + Reason string +} + +// OpenAIWSProtocolResolver 定义 OpenAI 上游协议决策。 +type OpenAIWSProtocolResolver interface { + Resolve(account *Account) OpenAIWSProtocolDecision +} + +type defaultOpenAIWSProtocolResolver struct { + cfg *config.Config +} + +// NewOpenAIWSProtocolResolver 创建默认协议决策器。 +func NewOpenAIWSProtocolResolver(cfg *config.Config) OpenAIWSProtocolResolver { + return &defaultOpenAIWSProtocolResolver{cfg: cfg} +} + +func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProtocolDecision { + if account == nil { + return openAIWSHTTPDecision("account_missing") + } + if !account.IsOpenAI() { + return openAIWSHTTPDecision("platform_not_openai") + } + if account.IsOpenAIWSForceHTTPEnabled() { + return openAIWSHTTPDecision("account_force_http") + } + if r == nil || r.cfg == nil { + return openAIWSHTTPDecision("config_missing") + } + + wsCfg := r.cfg.Gateway.OpenAIWS + if wsCfg.ForceHTTP { + return openAIWSHTTPDecision("global_force_http") + } + if !wsCfg.Enabled { + return openAIWSHTTPDecision("global_disabled") + } + if account.IsOpenAIOAuth() { + if !wsCfg.OAuthEnabled { + return openAIWSHTTPDecision("oauth_disabled") + } + } else if account.IsOpenAIApiKey() { + if !wsCfg.APIKeyEnabled { + return openAIWSHTTPDecision("apikey_disabled") + } + } else { + return openAIWSHTTPDecision("unknown_auth_type") + } + if wsCfg.ModeRouterV2Enabled { + mode := account.ResolveOpenAIResponsesWebSocketV2Mode(wsCfg.IngressModeDefault) + switch mode { + case OpenAIWSIngressModeOff: + return openAIWSHTTPDecision("account_mode_off") + case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + // continue + default: + return openAIWSHTTPDecision("account_mode_off") + } + if account.Concurrency <= 0 { + return openAIWSHTTPDecision("account_concurrency_invalid") + } + if wsCfg.ResponsesWebsocketsV2 { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_mode_" + mode, + } + } + if wsCfg.ResponsesWebsockets { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocket, + Reason: "ws_v1_mode_" + mode, + } + } + return openAIWSHTTPDecision("feature_disabled") + } + if !account.IsOpenAIResponsesWebSocketV2Enabled() { + return openAIWSHTTPDecision("account_disabled") + } + if wsCfg.ResponsesWebsocketsV2 { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_enabled", + } + } + if wsCfg.ResponsesWebsockets { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocket, + Reason: "ws_v1_enabled", + } + } + return openAIWSHTTPDecision("feature_disabled") +} + +func openAIWSHTTPDecision(reason string) OpenAIWSProtocolDecision { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportHTTPSSE, + Reason: reason, + } +} diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go new file mode 100644 index 00000000..5be76e28 --- /dev/null +++ b/backend/internal/service/openai_ws_protocol_resolver_test.go @@ -0,0 +1,203 @@ +package service + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) { + baseCfg := &config.Config{} + baseCfg.Gateway.OpenAIWS.Enabled = true + baseCfg.Gateway.OpenAIWS.OAuthEnabled = true + baseCfg.Gateway.OpenAIWS.APIKeyEnabled = true + baseCfg.Gateway.OpenAIWS.ResponsesWebsockets = false + baseCfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + openAIOAuthEnabled := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + }, + } + + t.Run("v2优先", func(t *testing.T) { + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_enabled", decision.Reason) + }) + + t.Run("v2关闭时回退v1", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + cfg.Gateway.OpenAIWS.ResponsesWebsockets = true + + decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocket, decision.Transport) + require.Equal(t, "ws_v1_enabled", decision.Reason) + }) + + t.Run("透传开关不影响WS协议判定", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "openai_passthrough": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_enabled", decision.Reason) + }) + + t.Run("账号级强制HTTP", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "openai_ws_force_http": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_force_http", decision.Reason) + }) + + t.Run("全局关闭保持HTTP", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.Enabled = false + decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "global_disabled", decision.Reason) + }) + + t.Run("账号开关关闭保持HTTP", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": false, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_disabled", decision.Reason) + }) + + t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_disabled", decision.Reason) + }) + + t.Run("兼容旧键openai_ws_enabled", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_ws_enabled": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_enabled", decision.Reason) + }) + + t.Run("按账号类型开关控制", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.OAuthEnabled = false + decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "oauth_disabled", decision.Reason) + }) + + t.Run("API Key 账号关闭开关时回退HTTP", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.APIKeyEnabled = false + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "apikey_disabled", decision.Reason) + }) + + t.Run("未知认证类型回退HTTP", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: "unknown_type", + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "unknown_auth_type", decision.Reason) + }) +} + +func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + + t.Run("dedicated mode routes to ws v2", func(t *testing.T) { + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_mode_dedicated", decision.Reason) + }) + + t.Run("off mode routes to http", func(t *testing.T) { + offAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(offAccount) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_mode_off", decision.Reason) + }) + + t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) { + legacyAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_mode_shared", decision.Reason) + }) + + t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) { + invalidConcurrency := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_concurrency_invalid", decision.Reason) + }) +} diff --git a/backend/internal/service/openai_ws_state_store.go b/backend/internal/service/openai_ws_state_store.go new file mode 100644 index 00000000..b606baa1 --- /dev/null +++ b/backend/internal/service/openai_ws_state_store.go @@ -0,0 +1,440 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" +) + +const ( + openAIWSResponseAccountCachePrefix = "openai:response:" + openAIWSStateStoreCleanupInterval = time.Minute + openAIWSStateStoreCleanupMaxPerMap = 512 + openAIWSStateStoreMaxEntriesPerMap = 65536 + openAIWSStateStoreRedisTimeout = 3 * time.Second +) + +type openAIWSAccountBinding struct { + accountID int64 + expiresAt time.Time +} + +type openAIWSConnBinding struct { + connID string + expiresAt time.Time +} + +type openAIWSTurnStateBinding struct { + turnState string + expiresAt time.Time +} + +type openAIWSSessionConnBinding struct { + connID string + expiresAt time.Time +} + +// OpenAIWSStateStore 管理 WSv2 的粘连状态。 +// - response_id -> account_id 用于续链路由 +// - response_id -> conn_id 用于连接内上下文复用 +// +// response_id -> account_id 优先走 GatewayCache(Redis),同时维护本地热缓存。 +// response_id -> conn_id 仅在本进程内有效。 +type OpenAIWSStateStore interface { + BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error + GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) + DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error + + BindResponseConn(responseID, connID string, ttl time.Duration) + GetResponseConn(responseID string) (string, bool) + DeleteResponseConn(responseID string) + + BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) + GetSessionTurnState(groupID int64, sessionHash string) (string, bool) + DeleteSessionTurnState(groupID int64, sessionHash string) + + BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) + GetSessionConn(groupID int64, sessionHash string) (string, bool) + DeleteSessionConn(groupID int64, sessionHash string) +} + +type defaultOpenAIWSStateStore struct { + cache GatewayCache + + responseToAccountMu sync.RWMutex + responseToAccount map[string]openAIWSAccountBinding + responseToConnMu sync.RWMutex + responseToConn map[string]openAIWSConnBinding + sessionToTurnStateMu sync.RWMutex + sessionToTurnState map[string]openAIWSTurnStateBinding + sessionToConnMu sync.RWMutex + sessionToConn map[string]openAIWSSessionConnBinding + + lastCleanupUnixNano atomic.Int64 +} + +// NewOpenAIWSStateStore 创建默认 WS 状态存储。 +func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore { + store := &defaultOpenAIWSStateStore{ + cache: cache, + responseToAccount: make(map[string]openAIWSAccountBinding, 256), + responseToConn: make(map[string]openAIWSConnBinding, 256), + sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256), + sessionToConn: make(map[string]openAIWSSessionConnBinding, 256), + } + store.lastCleanupUnixNano.Store(time.Now().UnixNano()) + return store +} + +func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" || accountID <= 0 { + return nil + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + expiresAt := time.Now().Add(ttl) + s.responseToAccountMu.Lock() + ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap) + s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt} + s.responseToAccountMu.Unlock() + + if s.cache == nil { + return nil + } + cacheKey := openAIWSResponseAccountCacheKey(id) + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) + defer cancel() + return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl) +} + +func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return 0, nil + } + s.maybeCleanup() + + now := time.Now() + s.responseToAccountMu.RLock() + if binding, ok := s.responseToAccount[id]; ok { + if now.Before(binding.expiresAt) { + accountID := binding.accountID + s.responseToAccountMu.RUnlock() + return accountID, nil + } + } + s.responseToAccountMu.RUnlock() + + if s.cache == nil { + return 0, nil + } + + cacheKey := openAIWSResponseAccountCacheKey(id) + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) + defer cancel() + accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey) + if err != nil || accountID <= 0 { + // 缓存读取失败不阻断主流程,按未命中降级。 + return 0, nil + } + return accountID, nil +} + +func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return nil + } + s.responseToAccountMu.Lock() + delete(s.responseToAccount, id) + s.responseToAccountMu.Unlock() + + if s.cache == nil { + return nil + } + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) + defer cancel() + return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id)) +} + +func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) { + id := normalizeOpenAIWSResponseID(responseID) + conn := strings.TrimSpace(connID) + if id == "" || conn == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.responseToConnMu.Lock() + ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap) + s.responseToConn[id] = openAIWSConnBinding{ + connID: conn, + expiresAt: time.Now().Add(ttl), + } + s.responseToConnMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return "", false + } + s.maybeCleanup() + + now := time.Now() + s.responseToConnMu.RLock() + binding, ok := s.responseToConn[id] + s.responseToConnMu.RUnlock() + if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" { + return "", false + } + return binding.connID, true +} + +func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return + } + s.responseToConnMu.Lock() + delete(s.responseToConn, id) + s.responseToConnMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + state := strings.TrimSpace(turnState) + if key == "" || state == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.sessionToTurnStateMu.Lock() + ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToTurnState[key] = openAIWSTurnStateBinding{ + turnState: state, + expiresAt: time.Now().Add(ttl), + } + s.sessionToTurnStateMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return "", false + } + s.maybeCleanup() + + now := time.Now() + s.sessionToTurnStateMu.RLock() + binding, ok := s.sessionToTurnState[key] + s.sessionToTurnStateMu.RUnlock() + if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.turnState) == "" { + return "", false + } + return binding.turnState, true +} + +func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return + } + s.sessionToTurnStateMu.Lock() + delete(s.sessionToTurnState, key) + s.sessionToTurnStateMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + conn := strings.TrimSpace(connID) + if key == "" || conn == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.sessionToConnMu.Lock() + ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToConn[key] = openAIWSSessionConnBinding{ + connID: conn, + expiresAt: time.Now().Add(ttl), + } + s.sessionToConnMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return "", false + } + s.maybeCleanup() + + now := time.Now() + s.sessionToConnMu.RLock() + binding, ok := s.sessionToConn[key] + s.sessionToConnMu.RUnlock() + if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" { + return "", false + } + return binding.connID, true +} + +func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return + } + s.sessionToConnMu.Lock() + delete(s.sessionToConn, key) + s.sessionToConnMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) maybeCleanup() { + if s == nil { + return + } + now := time.Now() + last := time.Unix(0, s.lastCleanupUnixNano.Load()) + if now.Sub(last) < openAIWSStateStoreCleanupInterval { + return + } + if !s.lastCleanupUnixNano.CompareAndSwap(last.UnixNano(), now.UnixNano()) { + return + } + + // 增量限额清理,避免高规模下一次性全量扫描导致长时间阻塞。 + s.responseToAccountMu.Lock() + cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap) + s.responseToAccountMu.Unlock() + + s.responseToConnMu.Lock() + cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap) + s.responseToConnMu.Unlock() + + s.sessionToTurnStateMu.Lock() + cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap) + s.sessionToTurnStateMu.Unlock() + + s.sessionToConnMu.Lock() + cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap) + s.sessionToConnMu.Unlock() +} + +func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) { + if len(bindings) < maxEntries || maxEntries <= 0 { + return + } + if _, exists := bindings[incomingKey]; exists { + return + } + // 固定上限保护:淘汰任意一项,优先保证内存有界。 + for key := range bindings { + delete(bindings, key) + return + } +} + +func normalizeOpenAIWSResponseID(responseID string) string { + return strings.TrimSpace(responseID) +} + +func openAIWSResponseAccountCacheKey(responseID string) string { + sum := sha256.Sum256([]byte(responseID)) + return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:]) +} + +func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration { + if ttl <= 0 { + return time.Hour + } + return ttl +} + +func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string { + hash := strings.TrimSpace(sessionHash) + if hash == "" { + return "" + } + return fmt.Sprintf("%d:%s", groupID, hash) +} + +func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) { + if ctx == nil { + ctx = context.Background() + } + return context.WithTimeout(ctx, openAIWSStateStoreRedisTimeout) +} diff --git a/backend/internal/service/openai_ws_state_store_test.go b/backend/internal/service/openai_ws_state_store_test.go new file mode 100644 index 00000000..235d4233 --- /dev/null +++ b/backend/internal/service/openai_ws_state_store_test.go @@ -0,0 +1,235 @@ +package service + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSStateStore_BindGetDeleteResponseAccount(t *testing.T) { + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(7) + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_abc", 101, time.Minute)) + + accountID, err := store.GetResponseAccount(ctx, groupID, "resp_abc") + require.NoError(t, err) + require.Equal(t, int64(101), accountID) + + require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_abc")) + accountID, err = store.GetResponseAccount(ctx, groupID, "resp_abc") + require.NoError(t, err) + require.Zero(t, accountID) +} + +func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindResponseConn("resp_conn", "conn_1", 30*time.Millisecond) + + connID, ok := store.GetResponseConn("resp_conn") + require.True(t, ok) + require.Equal(t, "conn_1", connID) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetResponseConn("resp_conn") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond) + + state, ok := store.GetSessionTurnState(9, "session_hash_1") + require.True(t, ok) + require.Equal(t, "turn_state_1", state) + + // group 隔离 + _, ok = store.GetSessionTurnState(10, "session_hash_1") + require.False(t, ok) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetSessionTurnState(9, "session_hash_1") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindSessionConn(9, "session_hash_conn_1", "conn_1", 30*time.Millisecond) + + connID, ok := store.GetSessionConn(9, "session_hash_conn_1") + require.True(t, ok) + require.Equal(t, "conn_1", connID) + + // group 隔离 + _, ok = store.GetSessionConn(10, "session_hash_conn_1") + require.False(t, ok) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetSessionConn(9, "session_hash_conn_1") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(17) + responseID := "resp_cache_stale" + cacheKey := openAIWSResponseAccountCacheKey(responseID) + + cache.sessionBindings[cacheKey] = 501 + accountID, err := store.GetResponseAccount(ctx, groupID, responseID) + require.NoError(t, err) + require.Equal(t, int64(501), accountID) + + delete(cache.sessionBindings, cacheKey) + accountID, err = store.GetResponseAccount(ctx, groupID, responseID) + require.NoError(t, err) + require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射") +} + +func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) { + raw := NewOpenAIWSStateStore(nil) + store, ok := raw.(*defaultOpenAIWSStateStore) + require.True(t, ok) + + expiredAt := time.Now().Add(-time.Minute) + total := 2048 + store.responseToConnMu.Lock() + for i := 0; i < total; i++ { + store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{ + connID: "conn_incremental", + expiresAt: expiredAt, + } + } + store.responseToConnMu.Unlock() + + store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano()) + store.maybeCleanup() + + store.responseToConnMu.RLock() + remainingAfterFirst := len(store.responseToConn) + store.responseToConnMu.RUnlock() + require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展") + require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键") + + for i := 0; i < 8; i++ { + store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano()) + store.maybeCleanup() + } + + store.responseToConnMu.RLock() + remaining := len(store.responseToConn) + store.responseToConnMu.RUnlock() + require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键") +} + +func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) { + bindings := map[string]int{ + "a": 1, + "b": 2, + } + + ensureBindingCapacity(bindings, "c", 2) + bindings["c"] = 3 + + require.Len(t, bindings, 2) + require.Equal(t, 3, bindings["c"]) +} + +func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) { + bindings := map[string]int{ + "a": 1, + "b": 2, + } + + ensureBindingCapacity(bindings, "a", 2) + bindings["a"] = 9 + + require.Len(t, bindings, 2) + require.Equal(t, 9, bindings["a"]) +} + +type openAIWSStateStoreTimeoutProbeCache struct { + setHasDeadline bool + getHasDeadline bool + deleteHasDeadline bool + setDeadlineDelta time.Duration + getDeadlineDelta time.Duration + delDeadlineDelta time.Duration +} + +func (c *openAIWSStateStoreTimeoutProbeCache) GetSessionAccountID(ctx context.Context, _ int64, _ string) (int64, error) { + if deadline, ok := ctx.Deadline(); ok { + c.getHasDeadline = true + c.getDeadlineDelta = time.Until(deadline) + } + return 123, nil +} + +func (c *openAIWSStateStoreTimeoutProbeCache) SetSessionAccountID(ctx context.Context, _ int64, _ string, _ int64, _ time.Duration) error { + if deadline, ok := ctx.Deadline(); ok { + c.setHasDeadline = true + c.setDeadlineDelta = time.Until(deadline) + } + return errors.New("set failed") +} + +func (c *openAIWSStateStoreTimeoutProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error { + return nil +} + +func (c *openAIWSStateStoreTimeoutProbeCache) DeleteSessionAccountID(ctx context.Context, _ int64, _ string) error { + if deadline, ok := ctx.Deadline(); ok { + c.deleteHasDeadline = true + c.delDeadlineDelta = time.Until(deadline) + } + return nil +} + +func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) { + probe := &openAIWSStateStoreTimeoutProbeCache{} + store := NewOpenAIWSStateStore(probe) + ctx := context.Background() + groupID := int64(5) + + err := store.BindResponseAccount(ctx, groupID, "resp_timeout_probe", 11, time.Minute) + require.Error(t, err) + + accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe") + require.NoError(t, getErr) + require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号") + + require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe")) + + require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文") + require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文") + require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取") + require.Greater(t, probe.setDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second) + require.Greater(t, probe.delDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe.delDeadlineDelta, 3*time.Second) + + probe2 := &openAIWSStateStoreTimeoutProbeCache{} + store2 := NewOpenAIWSStateStore(probe2) + accountID2, err2 := store2.GetResponseAccount(ctx, groupID, "resp_cache_only") + require.NoError(t, err2) + require.Equal(t, int64(123), accountID2) + require.True(t, probe2.getHasDeadline, "GetSessionAccountID 在缓存未命中时应携带独立超时上下文") + require.Greater(t, probe2.getDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second) +} + +func TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) { + ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + require.NotNil(t, ctx) + _, ok := ctx.Deadline() + require.True(t, ok, "应附加短超时") +} diff --git a/backend/internal/service/ops_account_availability.go b/backend/internal/service/ops_account_availability.go index a649e7b5..da66ec4d 100644 --- a/backend/internal/service/ops_account_availability.go +++ b/backend/internal/service/ops_account_availability.go @@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi } isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched - scopeRateLimits := acc.GetAntigravityScopeRateLimits() if acc.Platform != "" { if _, ok := platform[acc.Platform]; !ok { @@ -85,14 +84,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi if hasError { p.ErrorCount++ } - if len(scopeRateLimits) > 0 { - if p.ScopeRateLimitCount == nil { - p.ScopeRateLimitCount = make(map[string]int64) - } - for scope := range scopeRateLimits { - p.ScopeRateLimitCount[scope]++ - } - } } for _, grp := range acc.Groups { @@ -117,14 +108,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi if hasError { g.ErrorCount++ } - if len(scopeRateLimits) > 0 { - if g.ScopeRateLimitCount == nil { - g.ScopeRateLimitCount = make(map[string]int64) - } - for scope := range scopeRateLimits { - g.ScopeRateLimitCount[scope]++ - } - } } displayGroupID := int64(0) @@ -157,9 +140,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi item.RateLimitRemainingSec = &remainingSec } } - if len(scopeRateLimits) > 0 { - item.ScopeRateLimits = scopeRateLimits - } if isOverloaded && acc.OverloadUntil != nil { item.OverloadUntil = acc.OverloadUntil remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds()) diff --git a/backend/internal/service/ops_aggregation_service.go b/backend/internal/service/ops_aggregation_service.go index 972462ec..ec77fe12 100644 --- a/backend/internal/service/ops_aggregation_service.go +++ b/backend/internal/service/ops_aggregation_service.go @@ -5,12 +5,12 @@ import ( "database/sql" "errors" "fmt" - "log" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/google/uuid" "github.com/redis/go-redis/v9" ) @@ -190,7 +190,7 @@ func (s *OpsAggregationService) aggregateHourly() { latest, ok, err := s.opsRepo.GetLatestHourlyBucketStart(ctxMax) cancelMax() if err != nil { - log.Printf("[OpsAggregation][hourly] failed to read latest bucket: %v", err) + logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][hourly] failed to read latest bucket: %v", err) } else if ok { candidate := latest.Add(-opsAggHourlyOverlap) if candidate.After(start) { @@ -209,7 +209,7 @@ func (s *OpsAggregationService) aggregateHourly() { chunkEnd := minTime(cursor.Add(opsAggHourlyChunk), end) if err := s.opsRepo.UpsertHourlyMetrics(ctx, cursor, chunkEnd); err != nil { aggErr = err - log.Printf("[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err) + logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err) break } } @@ -288,7 +288,7 @@ func (s *OpsAggregationService) aggregateDaily() { latest, ok, err := s.opsRepo.GetLatestDailyBucketDate(ctxMax) cancelMax() if err != nil { - log.Printf("[OpsAggregation][daily] failed to read latest bucket: %v", err) + logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][daily] failed to read latest bucket: %v", err) } else if ok { candidate := latest.Add(-opsAggDailyOverlap) if candidate.After(start) { @@ -307,7 +307,7 @@ func (s *OpsAggregationService) aggregateDaily() { chunkEnd := minTime(cursor.Add(opsAggDailyChunk), end) if err := s.opsRepo.UpsertDailyMetrics(ctx, cursor, chunkEnd); err != nil { aggErr = err - log.Printf("[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err) + logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err) break } } @@ -427,7 +427,7 @@ func (s *OpsAggregationService) maybeLogSkip(prefix string) { if prefix == "" { prefix = "[OpsAggregation]" } - log.Printf("%s leader lock held by another instance; skipping", prefix) + logger.LegacyPrintf("service.ops_aggregation", "%s leader lock held by another instance; skipping", prefix) } func utcFloorToHour(t time.Time) time.Time { diff --git a/backend/internal/service/ops_alert_evaluator_service.go b/backend/internal/service/ops_alert_evaluator_service.go index 7c62e247..169a5e32 100644 --- a/backend/internal/service/ops_alert_evaluator_service.go +++ b/backend/internal/service/ops_alert_evaluator_service.go @@ -3,7 +3,6 @@ package service import ( "context" "fmt" - "log" "math" "strconv" "strings" @@ -11,6 +10,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/google/uuid" "github.com/redis/go-redis/v9" ) @@ -186,7 +186,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { rules, err := s.opsRepo.ListAlertRules(ctx) if err != nil { s.recordHeartbeatError(runAt, time.Since(startedAt), err) - log.Printf("[OpsAlertEvaluator] list rules failed: %v", err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] list rules failed: %v", err) return } @@ -236,7 +236,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { activeEvent, err := s.opsRepo.GetActiveAlertEvent(ctx, rule.ID) if err != nil { - log.Printf("[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err) continue } @@ -258,7 +258,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID) if err != nil { - log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err) continue } if latestEvent != nil && rule.CooldownMinutes > 0 { @@ -283,7 +283,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { created, err := s.opsRepo.CreateAlertEvent(ctx, firedEvent) if err != nil { - log.Printf("[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err) continue } @@ -300,7 +300,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { if activeEvent != nil { resolvedAt := now if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil { - log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err) } else { eventsResolved++ } @@ -779,7 +779,7 @@ func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc } if s.redisClient == nil { s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsAlertEvaluator] redis not configured; running without distributed lock") + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] redis not configured; running without distributed lock") }) return nil, true } @@ -797,7 +797,7 @@ func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc // Prefer fail-closed to avoid duplicate evaluators stampeding the DB when Redis is flaky. // Single-node deployments can disable the distributed lock via runtime settings. s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err) }) return nil, false } @@ -819,7 +819,7 @@ func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) { return } s.skipLogAt = now - log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key) } func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) { diff --git a/backend/internal/service/ops_alert_evaluator_service_test.go b/backend/internal/service/ops_alert_evaluator_service_test.go index 068ab6bb..83d358a3 100644 --- a/backend/internal/service/ops_alert_evaluator_service_test.go +++ b/backend/internal/service/ops_alert_evaluator_service_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" ) +var _ OpsRepository = (*stubOpsRepo)(nil) + type stubOpsRepo struct { OpsRepository overview *OpsDashboardOverview diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go index 1ade7176..1cae6fe5 100644 --- a/backend/internal/service/ops_cleanup_service.go +++ b/backend/internal/service/ops_cleanup_service.go @@ -4,12 +4,12 @@ import ( "context" "database/sql" "fmt" - "log" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/google/uuid" "github.com/redis/go-redis/v9" "github.com/robfig/cron/v3" @@ -75,11 +75,11 @@ func (s *OpsCleanupService) Start() { return } if s.cfg != nil && !s.cfg.Ops.Cleanup.Enabled { - log.Printf("[OpsCleanup] not started (disabled)") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (disabled)") return } if s.opsRepo == nil || s.db == nil { - log.Printf("[OpsCleanup] not started (missing deps)") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (missing deps)") return } @@ -99,12 +99,12 @@ func (s *OpsCleanupService) Start() { c := cron.New(cron.WithParser(opsCleanupCronParser), cron.WithLocation(loc)) _, err := c.AddFunc(schedule, func() { s.runScheduled() }) if err != nil { - log.Printf("[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err) return } s.cron = c s.cron.Start() - log.Printf("[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String()) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String()) }) } @@ -118,7 +118,7 @@ func (s *OpsCleanupService) Stop() { select { case <-ctx.Done(): case <-time.After(3 * time.Second): - log.Printf("[OpsCleanup] cron stop timed out") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cron stop timed out") } } }) @@ -146,17 +146,19 @@ func (s *OpsCleanupService) runScheduled() { counts, err := s.runCleanupOnce(ctx) if err != nil { s.recordHeartbeatError(runAt, time.Since(startedAt), err) - log.Printf("[OpsCleanup] cleanup failed: %v", err) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cleanup failed: %v", err) return } s.recordHeartbeatSuccess(runAt, time.Since(startedAt), counts) - log.Printf("[OpsCleanup] cleanup complete: %s", counts) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cleanup complete: %s", counts) } type opsCleanupDeletedCounts struct { errorLogs int64 retryAttempts int64 alertEvents int64 + systemLogs int64 + logAudits int64 systemMetrics int64 hourlyPreagg int64 dailyPreagg int64 @@ -164,10 +166,12 @@ type opsCleanupDeletedCounts struct { func (c opsCleanupDeletedCounts) String() string { return fmt.Sprintf( - "error_logs=%d retry_attempts=%d alert_events=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d", + "error_logs=%d retry_attempts=%d alert_events=%d system_logs=%d log_audits=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d", c.errorLogs, c.retryAttempts, c.alertEvents, + c.systemLogs, + c.logAudits, c.systemMetrics, c.hourlyPreagg, c.dailyPreagg, @@ -204,6 +208,18 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet return out, err } out.alertEvents = n + + n, err = deleteOldRowsByID(ctx, s.db, "ops_system_logs", "created_at", cutoff, batchSize, false) + if err != nil { + return out, err + } + out.systemLogs = n + + n, err = deleteOldRowsByID(ctx, s.db, "ops_system_log_cleanup_audits", "created_at", cutoff, batchSize, false) + if err != nil { + return out, err + } + out.logAudits = n } // Minute-level metrics snapshots. @@ -315,11 +331,11 @@ func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), b } // Redis error: fall back to DB advisory lock. s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsCleanup] leader lock SetNX failed; falling back to DB advisory lock: %v", err) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] leader lock SetNX failed; falling back to DB advisory lock: %v", err) }) } else { s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsCleanup] redis not configured; using DB advisory lock") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] redis not configured; using DB advisory lock") }) } diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index f6541d08..92b37e73 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ Page: page, PageSize: opsAccountsPageSize, - }, platformFilter, "", "", "") + }, platformFilter, "", "", "", 0) if err != nil { return nil, err } diff --git a/backend/internal/service/ops_log_runtime.go b/backend/internal/service/ops_log_runtime.go new file mode 100644 index 00000000..ed8aefa9 --- /dev/null +++ b/backend/internal/service/ops_log_runtime.go @@ -0,0 +1,267 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "go.uber.org/zap" +) + +func defaultOpsRuntimeLogConfig(cfg *config.Config) *OpsRuntimeLogConfig { + out := &OpsRuntimeLogConfig{ + Level: "info", + EnableSampling: false, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + } + if cfg == nil { + return out + } + out.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level)) + out.EnableSampling = cfg.Log.Sampling.Enabled + out.SamplingInitial = cfg.Log.Sampling.Initial + out.SamplingNext = cfg.Log.Sampling.Thereafter + out.Caller = cfg.Log.Caller + out.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) + if cfg.Ops.Cleanup.ErrorLogRetentionDays > 0 { + out.RetentionDays = cfg.Ops.Cleanup.ErrorLogRetentionDays + } + return out +} + +func normalizeOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig, defaults *OpsRuntimeLogConfig) { + if cfg == nil || defaults == nil { + return + } + cfg.Level = strings.ToLower(strings.TrimSpace(cfg.Level)) + if cfg.Level == "" { + cfg.Level = defaults.Level + } + cfg.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel)) + if cfg.StacktraceLevel == "" { + cfg.StacktraceLevel = defaults.StacktraceLevel + } + if cfg.SamplingInitial <= 0 { + cfg.SamplingInitial = defaults.SamplingInitial + } + if cfg.SamplingNext <= 0 { + cfg.SamplingNext = defaults.SamplingNext + } + if cfg.RetentionDays <= 0 { + cfg.RetentionDays = defaults.RetentionDays + } +} + +func validateOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig) error { + if cfg == nil { + return errors.New("invalid config") + } + switch strings.ToLower(strings.TrimSpace(cfg.Level)) { + case "debug", "info", "warn", "error": + default: + return errors.New("level must be one of: debug/info/warn/error") + } + switch strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel)) { + case "none", "error", "fatal": + default: + return errors.New("stacktrace_level must be one of: none/error/fatal") + } + if cfg.SamplingInitial <= 0 { + return errors.New("sampling_initial must be positive") + } + if cfg.SamplingNext <= 0 { + return errors.New("sampling_thereafter must be positive") + } + if cfg.RetentionDays < 1 || cfg.RetentionDays > 3650 { + return errors.New("retention_days must be between 1 and 3650") + } + return nil +} + +func (s *OpsService) GetRuntimeLogConfig(ctx context.Context) (*OpsRuntimeLogConfig, error) { + if s == nil || s.settingRepo == nil { + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + defaultCfg := defaultOpsRuntimeLogConfig(cfg) + return defaultCfg, nil + } + defaultCfg := defaultOpsRuntimeLogConfig(s.cfg) + if ctx == nil { + ctx = context.Background() + } + + raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsRuntimeLogConfig) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + b, _ := json.Marshal(defaultCfg) + _ = s.settingRepo.Set(ctx, SettingKeyOpsRuntimeLogConfig, string(b)) + return defaultCfg, nil + } + return nil, err + } + + cfg := &OpsRuntimeLogConfig{} + if err := json.Unmarshal([]byte(raw), cfg); err != nil { + return defaultCfg, nil + } + normalizeOpsRuntimeLogConfig(cfg, defaultCfg) + return cfg, nil +} + +func (s *OpsService) UpdateRuntimeLogConfig(ctx context.Context, req *OpsRuntimeLogConfig, operatorID int64) (*OpsRuntimeLogConfig, error) { + if s == nil || s.settingRepo == nil { + return nil, errors.New("setting repository not initialized") + } + if req == nil { + return nil, errors.New("invalid config") + } + if ctx == nil { + ctx = context.Background() + } + if operatorID <= 0 { + return nil, errors.New("invalid operator id") + } + + oldCfg, err := s.GetRuntimeLogConfig(ctx) + if err != nil { + return nil, err + } + next := *req + normalizeOpsRuntimeLogConfig(&next, defaultOpsRuntimeLogConfig(s.cfg)) + if err := validateOpsRuntimeLogConfig(&next); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "validation_failed: "+err.Error()) + return nil, err + } + + if err := applyOpsRuntimeLogConfig(&next); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "apply_failed: "+err.Error()) + return nil, err + } + + next.Source = "runtime_setting" + next.UpdatedAt = time.Now().UTC().Format(time.RFC3339Nano) + next.UpdatedByUserID = operatorID + + encoded, err := json.Marshal(&next) + if err != nil { + return nil, err + } + if err := s.settingRepo.Set(ctx, SettingKeyOpsRuntimeLogConfig, string(encoded)); err != nil { + // 存储失败时回滚到旧配置,避免内存状态与持久化状态不一致。 + _ = applyOpsRuntimeLogConfig(oldCfg) + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "persist_failed: "+err.Error()) + return nil, err + } + + s.auditRuntimeLogConfigChange(operatorID, oldCfg, &next, "updated") + + return &next, nil +} + +func (s *OpsService) ResetRuntimeLogConfig(ctx context.Context, operatorID int64) (*OpsRuntimeLogConfig, error) { + if s == nil || s.settingRepo == nil { + return nil, errors.New("setting repository not initialized") + } + if ctx == nil { + ctx = context.Background() + } + if operatorID <= 0 { + return nil, errors.New("invalid operator id") + } + + oldCfg, err := s.GetRuntimeLogConfig(ctx) + if err != nil { + return nil, err + } + + resetCfg := defaultOpsRuntimeLogConfig(s.cfg) + normalizeOpsRuntimeLogConfig(resetCfg, defaultOpsRuntimeLogConfig(s.cfg)) + if err := validateOpsRuntimeLogConfig(resetCfg); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_validation_failed: "+err.Error()) + return nil, err + } + if err := applyOpsRuntimeLogConfig(resetCfg); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_apply_failed: "+err.Error()) + return nil, err + } + + // 清理 runtime 覆盖配置,回退到 env/yaml baseline。 + if err := s.settingRepo.Delete(ctx, SettingKeyOpsRuntimeLogConfig); err != nil && !errors.Is(err, ErrSettingNotFound) { + _ = applyOpsRuntimeLogConfig(oldCfg) + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_persist_failed: "+err.Error()) + return nil, err + } + + now := time.Now().UTC().Format(time.RFC3339Nano) + resetCfg.Source = "baseline" + resetCfg.UpdatedAt = now + resetCfg.UpdatedByUserID = operatorID + + s.auditRuntimeLogConfigChange(operatorID, oldCfg, resetCfg, "reset") + return resetCfg, nil +} + +func applyOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig) error { + if cfg == nil { + return fmt.Errorf("nil runtime log config") + } + if err := logger.Reconfigure(func(opts *logger.InitOptions) error { + opts.Level = strings.ToLower(strings.TrimSpace(cfg.Level)) + opts.Caller = cfg.Caller + opts.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel)) + opts.Sampling.Enabled = cfg.EnableSampling + opts.Sampling.Initial = cfg.SamplingInitial + opts.Sampling.Thereafter = cfg.SamplingNext + return nil + }); err != nil { + return err + } + return nil +} + +func (s *OpsService) applyRuntimeLogConfigOnStartup(ctx context.Context) { + if s == nil { + return + } + cfg, err := s.GetRuntimeLogConfig(ctx) + if err != nil { + return + } + _ = applyOpsRuntimeLogConfig(cfg) +} + +func (s *OpsService) auditRuntimeLogConfigChange(operatorID int64, oldCfg *OpsRuntimeLogConfig, newCfg *OpsRuntimeLogConfig, action string) { + oldRaw, _ := json.Marshal(oldCfg) + newRaw, _ := json.Marshal(newCfg) + logger.With( + zap.String("component", "audit.log_config_change"), + zap.String("action", strings.TrimSpace(action)), + zap.Int64("operator_id", operatorID), + zap.String("old", string(oldRaw)), + zap.String("new", string(newRaw)), + ).Info("runtime log config changed") +} + +func (s *OpsService) auditRuntimeLogConfigFailure(operatorID int64, oldCfg *OpsRuntimeLogConfig, newCfg *OpsRuntimeLogConfig, reason string) { + oldRaw, _ := json.Marshal(oldCfg) + newRaw, _ := json.Marshal(newCfg) + logger.With( + zap.String("component", "audit.log_config_change"), + zap.String("action", "failed"), + zap.Int64("operator_id", operatorID), + zap.String("reason", strings.TrimSpace(reason)), + zap.String("old", string(oldRaw)), + zap.String("new", string(newRaw)), + ).Warn("runtime log config change failed") +} diff --git a/backend/internal/service/ops_log_runtime_test.go b/backend/internal/service/ops_log_runtime_test.go new file mode 100644 index 00000000..658b4812 --- /dev/null +++ b/backend/internal/service/ops_log_runtime_test.go @@ -0,0 +1,570 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +type runtimeSettingRepoStub struct { + values map[string]string + deleted map[string]bool + setCalls int + getValueFn func(key string) (string, error) + setFn func(key, value string) error + deleteFn func(key string) error +} + +func newRuntimeSettingRepoStub() *runtimeSettingRepoStub { + return &runtimeSettingRepoStub{ + values: map[string]string{}, + deleted: map[string]bool{}, + } +} + +func (s *runtimeSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + value, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &Setting{Key: key, Value: value}, nil +} + +func (s *runtimeSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if s.getValueFn != nil { + return s.getValueFn(key) + } + value, ok := s.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (s *runtimeSettingRepoStub) Set(_ context.Context, key, value string) error { + if s.setFn != nil { + if err := s.setFn(key, value); err != nil { + return err + } + } + s.values[key] = value + s.setCalls++ + return nil +} + +func (s *runtimeSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *runtimeSettingRepoStub) SetMultiple(_ context.Context, settings map[string]string) error { + for key, value := range settings { + s.values[key] = value + } + return nil +} + +func (s *runtimeSettingRepoStub) GetAll(_ context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *runtimeSettingRepoStub) Delete(_ context.Context, key string) error { + if s.deleteFn != nil { + if err := s.deleteFn(key); err != nil { + return err + } + } + if _, ok := s.values[key]; !ok { + return ErrSettingNotFound + } + delete(s.values, key) + s.deleted[key] = true + return nil +} + +func TestUpdateRuntimeLogConfig_InvalidConfigShouldNotApply(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "trace", + EnableSampling: true, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + }, 1) + if err == nil { + t.Fatalf("expected validation error") + } + if logger.CurrentLevel() != "info" { + t.Fatalf("logger level changed unexpectedly: %s", logger.CurrentLevel()) + } + if repo.setCalls != 1 { + // GetRuntimeLogConfig() 会在 key 缺失时写入默认值,此处应只有这一次持久化。 + t.Fatalf("unexpected set calls: %d", repo.setCalls) + } +} + +func TestResetRuntimeLogConfig_ShouldFallbackToBaseline(t *testing.T) { + repo := newRuntimeSettingRepoStub() + existing := &OpsRuntimeLogConfig{ + Level: "debug", + EnableSampling: true, + SamplingInitial: 50, + SamplingNext: 50, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 60, + Source: "runtime_setting", + } + raw, _ := json.Marshal(existing) + repo.values[SettingKeyOpsRuntimeLogConfig] = string(raw) + + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "warn", + Caller: false, + StacktraceLevel: "fatal", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + Ops: config.OpsConfig{ + Cleanup: config.OpsCleanupConfig{ + ErrorLogRetentionDays: 45, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + resetCfg, err := svc.ResetRuntimeLogConfig(context.Background(), 9) + if err != nil { + t.Fatalf("ResetRuntimeLogConfig() error: %v", err) + } + if resetCfg.Source != "baseline" { + t.Fatalf("source = %q, want baseline", resetCfg.Source) + } + if resetCfg.Level != "warn" { + t.Fatalf("level = %q, want warn", resetCfg.Level) + } + if resetCfg.RetentionDays != 45 { + t.Fatalf("retention_days = %d, want 45", resetCfg.RetentionDays) + } + if logger.CurrentLevel() != "warn" { + t.Fatalf("logger level = %q, want warn", logger.CurrentLevel()) + } + if !repo.deleted[SettingKeyOpsRuntimeLogConfig] { + t.Fatalf("runtime setting key should be deleted") + } +} + +func TestResetRuntimeLogConfig_InvalidOperator(t *testing.T) { + svc := &OpsService{settingRepo: newRuntimeSettingRepoStub()} + _, err := svc.ResetRuntimeLogConfig(context.Background(), 0) + if err == nil { + t.Fatalf("expected invalid operator error") + } + if err.Error() != "invalid operator id" { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestGetRuntimeLogConfig_InvalidJSONFallback(t *testing.T) { + repo := newRuntimeSettingRepoStub() + repo.values[SettingKeyOpsRuntimeLogConfig] = `{invalid-json}` + + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "warn", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + got, err := svc.GetRuntimeLogConfig(context.Background()) + if err != nil { + t.Fatalf("GetRuntimeLogConfig() error: %v", err) + } + if got.Level != "warn" { + t.Fatalf("level = %q, want warn", got.Level) + } +} + +func TestUpdateRuntimeLogConfig_PersistFailureRollback(t *testing.T) { + repo := newRuntimeSettingRepoStub() + oldCfg := &OpsRuntimeLogConfig{ + Level: "info", + EnableSampling: false, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + } + raw, _ := json.Marshal(oldCfg) + repo.values[SettingKeyOpsRuntimeLogConfig] = string(raw) + repo.setFn = func(key, value string) error { + if key == SettingKeyOpsRuntimeLogConfig { + return errors.New("db down") + } + return nil + } + + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "debug", + EnableSampling: false, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + }, 5) + if err == nil { + t.Fatalf("expected persist error") + } + // Persist failure should rollback runtime level back to old effective level. + if logger.CurrentLevel() != "info" { + t.Fatalf("logger level should rollback to info, got %s", logger.CurrentLevel()) + } +} + +func TestApplyRuntimeLogConfigOnStartup(t *testing.T) { + repo := newRuntimeSettingRepoStub() + cfgRaw := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}` + repo.values[SettingKeyOpsRuntimeLogConfig] = cfgRaw + + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + svc.applyRuntimeLogConfigOnStartup(context.Background()) + if logger.CurrentLevel() != "debug" { + t.Fatalf("expected startup apply debug, got %s", logger.CurrentLevel()) + } +} + +func TestDefaultNormalizeAndValidateRuntimeLogConfig(t *testing.T) { + defaults := defaultOpsRuntimeLogConfig(&config.Config{ + Log: config.LogConfig{ + Level: "DEBUG", + Caller: false, + StacktraceLevel: "FATAL", + Sampling: config.LogSamplingConfig{ + Enabled: true, + Initial: 50, + Thereafter: 20, + }, + }, + Ops: config.OpsConfig{ + Cleanup: config.OpsCleanupConfig{ + ErrorLogRetentionDays: 7, + }, + }, + }) + if defaults.Level != "debug" || defaults.StacktraceLevel != "fatal" || defaults.RetentionDays != 7 { + t.Fatalf("unexpected defaults: %+v", defaults) + } + + cfg := &OpsRuntimeLogConfig{ + Level: " ", + EnableSampling: true, + SamplingInitial: 0, + SamplingNext: -1, + Caller: true, + StacktraceLevel: "", + RetentionDays: 0, + } + normalizeOpsRuntimeLogConfig(cfg, defaults) + if cfg.Level != "debug" || cfg.StacktraceLevel != "fatal" { + t.Fatalf("normalize level/stacktrace failed: %+v", cfg) + } + if cfg.SamplingInitial != 50 || cfg.SamplingNext != 20 || cfg.RetentionDays != 7 { + t.Fatalf("normalize numeric defaults failed: %+v", cfg) + } + if err := validateOpsRuntimeLogConfig(cfg); err != nil { + t.Fatalf("validate normalized config should pass: %v", err) + } +} + +func TestValidateRuntimeLogConfigErrors(t *testing.T) { + cases := []struct { + name string + cfg *OpsRuntimeLogConfig + }{ + {name: "nil", cfg: nil}, + {name: "bad level", cfg: &OpsRuntimeLogConfig{Level: "trace", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 1}}, + {name: "bad stack", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "warn", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 1}}, + {name: "bad initial", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 0, SamplingNext: 1, RetentionDays: 1}}, + {name: "bad next", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 0, RetentionDays: 1}}, + {name: "bad retention", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 0}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if err := validateOpsRuntimeLogConfig(tc.cfg); err == nil { + t.Fatalf("expected validation error") + } + }) + } +} + +func TestGetRuntimeLogConfigFallbackAndErrors(t *testing.T) { + var nilSvc *OpsService + cfg, err := nilSvc.GetRuntimeLogConfig(context.Background()) + if err != nil { + t.Fatalf("nil svc should fallback default: %v", err) + } + if cfg.Level != "info" { + t.Fatalf("unexpected nil svc default level: %s", cfg.Level) + } + + repo := newRuntimeSettingRepoStub() + repo.getValueFn = func(key string) (string, error) { + return "", errors.New("boom") + } + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "warn", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + if _, err := svc.GetRuntimeLogConfig(context.Background()); err == nil { + t.Fatalf("expected get value error") + } +} + +func TestUpdateRuntimeLogConfig_PreconditionErrors(t *testing.T) { + svc := &OpsService{} + if _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{}, 1); err == nil { + t.Fatalf("expected setting repo not initialized") + } + + svc = &OpsService{settingRepo: newRuntimeSettingRepoStub()} + if _, err := svc.UpdateRuntimeLogConfig(context.Background(), nil, 1); err == nil { + t.Fatalf("expected invalid config") + } + if _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "info", + StacktraceLevel: "error", + SamplingInitial: 1, + SamplingNext: 1, + RetentionDays: 1, + }, 0); err == nil { + t.Fatalf("expected invalid operator") + } +} + +func TestUpdateRuntimeLogConfig_Success(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + next, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "debug", + EnableSampling: false, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + }, 2) + if err != nil { + t.Fatalf("UpdateRuntimeLogConfig() error: %v", err) + } + if next.Source != "runtime_setting" || next.UpdatedByUserID != 2 || next.UpdatedAt == "" { + t.Fatalf("unexpected metadata: %+v", next) + } + if logger.CurrentLevel() != "debug" { + t.Fatalf("expected applied level debug, got %s", logger.CurrentLevel()) + } +} + +func TestResetRuntimeLogConfig_IgnoreNotFoundDelete(t *testing.T) { + repo := newRuntimeSettingRepoStub() + repo.deleteFn = func(key string) error { return ErrSettingNotFound } + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + if _, err := svc.ResetRuntimeLogConfig(context.Background(), 1); err != nil { + t.Fatalf("reset should ignore ErrSettingNotFound: %v", err) + } +} + +func TestApplyRuntimeLogConfigHelpers(t *testing.T) { + if err := applyOpsRuntimeLogConfig(nil); err == nil { + t.Fatalf("expected nil config error") + } + + normalizeOpsRuntimeLogConfig(nil, &OpsRuntimeLogConfig{Level: "info"}) + normalizeOpsRuntimeLogConfig(&OpsRuntimeLogConfig{Level: "debug"}, nil) + + var nilSvc *OpsService + nilSvc.applyRuntimeLogConfigOnStartup(context.Background()) +} diff --git a/backend/internal/service/ops_models.go b/backend/internal/service/ops_models.go index 347cd52b..2ed06d90 100644 --- a/backend/internal/service/ops_models.go +++ b/backend/internal/service/ops_models.go @@ -2,6 +2,21 @@ package service import "time" +type OpsSystemLog struct { + ID int64 `json:"id"` + CreatedAt time.Time `json:"created_at"` + Level string `json:"level"` + Component string `json:"component"` + Message string `json:"message"` + RequestID string `json:"request_id"` + ClientRequestID string `json:"client_request_id"` + UserID *int64 `json:"user_id"` + AccountID *int64 `json:"account_id"` + Platform string `json:"platform"` + Model string `json:"model"` + Extra map[string]any `json:"extra,omitempty"` +} + type OpsErrorLog struct { ID int64 `json:"id"` CreatedAt time.Time `json:"created_at"` diff --git a/backend/internal/service/ops_openai_token_stats.go b/backend/internal/service/ops_openai_token_stats.go new file mode 100644 index 00000000..63f88ba0 --- /dev/null +++ b/backend/internal/service/ops_openai_token_stats.go @@ -0,0 +1,55 @@ +package service + +import ( + "context" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func (s *OpsService) GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) { + if err := s.RequireMonitoringEnabled(ctx); err != nil { + return nil, err + } + if s.opsRepo == nil { + return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available") + } + if filter == nil { + return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required") + } + if filter.StartTime.IsZero() || filter.EndTime.IsZero() { + return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required") + } + if filter.StartTime.After(filter.EndTime) { + return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") + } + + if filter.GroupID != nil && *filter.GroupID <= 0 { + return nil, infraerrors.BadRequest("OPS_GROUP_ID_INVALID", "group_id must be > 0") + } + + // top_n cannot be mixed with page/page_size params. + if filter.TopN > 0 && (filter.Page > 0 || filter.PageSize > 0) { + return nil, infraerrors.BadRequest("OPS_PAGINATION_CONFLICT", "top_n cannot be used with page/page_size") + } + + if filter.TopN > 0 { + if filter.TopN < 1 || filter.TopN > 100 { + return nil, infraerrors.BadRequest("OPS_TOPN_INVALID", "top_n must be between 1 and 100") + } + } else { + if filter.Page <= 0 { + filter.Page = 1 + } + if filter.PageSize <= 0 { + filter.PageSize = 20 + } + if filter.Page < 1 { + return nil, infraerrors.BadRequest("OPS_PAGE_INVALID", "page must be >= 1") + } + if filter.PageSize < 1 || filter.PageSize > 100 { + return nil, infraerrors.BadRequest("OPS_PAGE_SIZE_INVALID", "page_size must be between 1 and 100") + } + } + + return s.opsRepo.GetOpenAITokenStats(ctx, filter) +} diff --git a/backend/internal/service/ops_openai_token_stats_models.go b/backend/internal/service/ops_openai_token_stats_models.go new file mode 100644 index 00000000..ef40fa1f --- /dev/null +++ b/backend/internal/service/ops_openai_token_stats_models.go @@ -0,0 +1,54 @@ +package service + +import "time" + +type OpsOpenAITokenStatsFilter struct { + TimeRange string + StartTime time.Time + EndTime time.Time + + Platform string + GroupID *int64 + + // Pagination mode (default): page/page_size + Page int + PageSize int + + // TopN mode: top_n + TopN int +} + +func (f *OpsOpenAITokenStatsFilter) IsTopNMode() bool { + return f != nil && f.TopN > 0 +} + +type OpsOpenAITokenStatsItem struct { + Model string `json:"model"` + RequestCount int64 `json:"request_count"` + AvgTokensPerSec *float64 `json:"avg_tokens_per_sec"` + AvgFirstTokenMs *float64 `json:"avg_first_token_ms"` + TotalOutputTokens int64 `json:"total_output_tokens"` + AvgDurationMs int64 `json:"avg_duration_ms"` + RequestsWithFirstToken int64 `json:"requests_with_first_token"` +} + +type OpsOpenAITokenStatsResponse struct { + TimeRange string `json:"time_range"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + + Platform string `json:"platform,omitempty"` + GroupID *int64 `json:"group_id,omitempty"` + + Items []*OpsOpenAITokenStatsItem `json:"items"` + + // Total model rows before pagination/topN trimming. + Total int64 `json:"total"` + + // Pagination mode metadata. + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + + // TopN mode metadata. + TopN *int `json:"top_n,omitempty"` +} diff --git a/backend/internal/service/ops_openai_token_stats_test.go b/backend/internal/service/ops_openai_token_stats_test.go new file mode 100644 index 00000000..ee332f91 --- /dev/null +++ b/backend/internal/service/ops_openai_token_stats_test.go @@ -0,0 +1,162 @@ +package service + +import ( + "context" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type openAITokenStatsRepoStub struct { + OpsRepository + resp *OpsOpenAITokenStatsResponse + err error + captured *OpsOpenAITokenStatsFilter +} + +func (s *openAITokenStatsRepoStub) GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) { + s.captured = filter + if s.err != nil { + return nil, s.err + } + if s.resp != nil { + return s.resp, nil + } + return &OpsOpenAITokenStatsResponse{}, nil +} + +func TestOpsServiceGetOpenAITokenStats_Validation(t *testing.T) { + now := time.Now().UTC() + + tests := []struct { + name string + filter *OpsOpenAITokenStatsFilter + wantCode int + wantReason string + }{ + { + name: "filter 不能为空", + filter: nil, + wantCode: 400, + wantReason: "OPS_FILTER_REQUIRED", + }, + { + name: "start_time/end_time 必填", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: time.Time{}, + EndTime: now, + }, + wantCode: 400, + wantReason: "OPS_TIME_RANGE_REQUIRED", + }, + { + name: "start_time 不能晚于 end_time", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now, + EndTime: now.Add(-1 * time.Minute), + }, + wantCode: 400, + wantReason: "OPS_TIME_RANGE_INVALID", + }, + { + name: "group_id 必须大于 0", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + GroupID: int64Ptr(0), + }, + wantCode: 400, + wantReason: "OPS_GROUP_ID_INVALID", + }, + { + name: "top_n 与分页参数互斥", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + TopN: 10, + Page: 1, + }, + wantCode: 400, + wantReason: "OPS_PAGINATION_CONFLICT", + }, + { + name: "top_n 参数越界", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + TopN: 101, + }, + wantCode: 400, + wantReason: "OPS_TOPN_INVALID", + }, + { + name: "page_size 参数越界", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + Page: 1, + PageSize: 101, + }, + wantCode: 400, + wantReason: "OPS_PAGE_SIZE_INVALID", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := &OpsService{ + opsRepo: &openAITokenStatsRepoStub{}, + } + + _, err := svc.GetOpenAITokenStats(context.Background(), tt.filter) + require.Error(t, err) + require.Equal(t, tt.wantCode, infraerrors.Code(err)) + require.Equal(t, tt.wantReason, infraerrors.Reason(err)) + }) + } +} + +func TestOpsServiceGetOpenAITokenStats_DefaultPagination(t *testing.T) { + now := time.Now().UTC() + repo := &openAITokenStatsRepoStub{ + resp: &OpsOpenAITokenStatsResponse{ + Items: []*OpsOpenAITokenStatsItem{ + {Model: "gpt-4o-mini", RequestCount: 10}, + }, + Total: 1, + }, + } + svc := &OpsService{opsRepo: repo} + + filter := &OpsOpenAITokenStatsFilter{ + TimeRange: "30d", + StartTime: now.Add(-30 * 24 * time.Hour), + EndTime: now, + } + resp, err := svc.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, repo.captured) + require.Equal(t, 1, repo.captured.Page) + require.Equal(t, 20, repo.captured.PageSize) + require.Equal(t, 0, repo.captured.TopN) +} + +func TestOpsServiceGetOpenAITokenStats_RepoUnavailable(t *testing.T) { + now := time.Now().UTC() + svc := &OpsService{} + + _, err := svc.GetOpenAITokenStats(context.Background(), &OpsOpenAITokenStatsFilter{ + TimeRange: "1h", + StartTime: now.Add(-time.Hour), + EndTime: now, + TopN: 10, + }) + require.Error(t, err) + require.Equal(t, 503, infraerrors.Code(err)) + require.Equal(t, "OPS_REPO_UNAVAILABLE", infraerrors.Reason(err)) +} + +func int64Ptr(v int64) *int64 { return &v } diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index 347b06b5..f3633eae 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -10,6 +10,10 @@ type OpsRepository interface { ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error) + BatchInsertSystemLogs(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) + ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) + DeleteSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) + InsertSystemLogCleanupAudit(ctx context.Context, input *OpsSystemLogCleanupAudit) error InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error) UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error @@ -27,6 +31,7 @@ type OpsRepository interface { GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) + GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) InsertSystemMetrics(ctx context.Context, input *OpsInsertSystemMetricsInput) error GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*OpsSystemMetricsSnapshot, error) @@ -98,6 +103,10 @@ type OpsInsertErrorLogInput struct { // It is set by OpsService.RecordError before persisting. UpstreamErrorsJSON *string + AuthLatencyMs *int64 + RoutingLatencyMs *int64 + UpstreamLatencyMs *int64 + ResponseLatencyMs *int64 TimeToFirstTokenMs *int64 RequestBodyJSON *string // sanitized json string (not raw bytes) @@ -200,6 +209,69 @@ type OpsInsertSystemMetricsInput struct { ConcurrencyQueueDepth *int } +type OpsInsertSystemLogInput struct { + CreatedAt time.Time + Level string + Component string + Message string + RequestID string + ClientRequestID string + UserID *int64 + AccountID *int64 + Platform string + Model string + ExtraJSON string +} + +type OpsSystemLogFilter struct { + StartTime *time.Time + EndTime *time.Time + + Level string + Component string + + RequestID string + ClientRequestID string + UserID *int64 + AccountID *int64 + Platform string + Model string + Query string + + Page int + PageSize int +} + +type OpsSystemLogCleanupFilter struct { + StartTime *time.Time + EndTime *time.Time + + Level string + Component string + + RequestID string + ClientRequestID string + UserID *int64 + AccountID *int64 + Platform string + Model string + Query string +} + +type OpsSystemLogList struct { + Logs []*OpsSystemLog `json:"logs"` + Total int `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +type OpsSystemLogCleanupAudit struct { + CreatedAt time.Time + OperatorID int64 + Conditions string + DeletedRows int64 +} + type OpsSystemMetricsSnapshot struct { ID int64 `json:"id"` CreatedAt time.Time `json:"created_at"` diff --git a/backend/internal/service/ops_realtime_models.go b/backend/internal/service/ops_realtime_models.go index 33029f59..a19ab355 100644 --- a/backend/internal/service/ops_realtime_models.go +++ b/backend/internal/service/ops_realtime_models.go @@ -50,24 +50,22 @@ type UserConcurrencyInfo struct { // PlatformAvailability aggregates account availability by platform. type PlatformAvailability struct { - Platform string `json:"platform"` - TotalAccounts int64 `json:"total_accounts"` - AvailableCount int64 `json:"available_count"` - RateLimitCount int64 `json:"rate_limit_count"` - ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"` - ErrorCount int64 `json:"error_count"` + Platform string `json:"platform"` + TotalAccounts int64 `json:"total_accounts"` + AvailableCount int64 `json:"available_count"` + RateLimitCount int64 `json:"rate_limit_count"` + ErrorCount int64 `json:"error_count"` } // GroupAvailability aggregates account availability by group. type GroupAvailability struct { - GroupID int64 `json:"group_id"` - GroupName string `json:"group_name"` - Platform string `json:"platform"` - TotalAccounts int64 `json:"total_accounts"` - AvailableCount int64 `json:"available_count"` - RateLimitCount int64 `json:"rate_limit_count"` - ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"` - ErrorCount int64 `json:"error_count"` + GroupID int64 `json:"group_id"` + GroupName string `json:"group_name"` + Platform string `json:"platform"` + TotalAccounts int64 `json:"total_accounts"` + AvailableCount int64 `json:"available_count"` + RateLimitCount int64 `json:"rate_limit_count"` + ErrorCount int64 `json:"error_count"` } // AccountAvailability represents current availability for a single account. @@ -85,11 +83,10 @@ type AccountAvailability struct { IsOverloaded bool `json:"is_overloaded"` HasError bool `json:"has_error"` - RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` - RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"` - ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"` - OverloadUntil *time.Time `json:"overload_until"` - OverloadRemainingSec *int64 `json:"overload_remaining_sec"` - ErrorMessage string `json:"error_message"` - TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"` + RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` + RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"` + OverloadUntil *time.Time `json:"overload_until"` + OverloadRemainingSec *int64 `json:"overload_remaining_sec"` + ErrorMessage string `json:"error_message"` + TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"` } diff --git a/backend/internal/service/ops_repo_mock_test.go b/backend/internal/service/ops_repo_mock_test.go new file mode 100644 index 00000000..e250dea3 --- /dev/null +++ b/backend/internal/service/ops_repo_mock_test.go @@ -0,0 +1,196 @@ +package service + +import ( + "context" + "time" +) + +// opsRepoMock is a test-only OpsRepository implementation with optional function hooks. +type opsRepoMock struct { + BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) + ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) + DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) + InsertSystemLogCleanupAuditFn func(ctx context.Context, input *OpsSystemLogCleanupAudit) error +} + +func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) { + return 0, nil +} + +func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) { + return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil +} + +func (m *opsRepoMock) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) { + return &OpsErrorLogDetail{}, nil +} + +func (m *opsRepoMock) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error) { + return []*OpsRequestDetail{}, 0, nil +} + +func (m *opsRepoMock) BatchInsertSystemLogs(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + if m.BatchInsertSystemLogsFn != nil { + return m.BatchInsertSystemLogsFn(ctx, inputs) + } + return int64(len(inputs)), nil +} + +func (m *opsRepoMock) ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + if m.ListSystemLogsFn != nil { + return m.ListSystemLogsFn(ctx, filter) + } + return &OpsSystemLogList{Logs: []*OpsSystemLog{}, Total: 0, Page: 1, PageSize: 50}, nil +} + +func (m *opsRepoMock) DeleteSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + if m.DeleteSystemLogsFn != nil { + return m.DeleteSystemLogsFn(ctx, filter) + } + return 0, nil +} + +func (m *opsRepoMock) InsertSystemLogCleanupAudit(ctx context.Context, input *OpsSystemLogCleanupAudit) error { + if m.InsertSystemLogCleanupAuditFn != nil { + return m.InsertSystemLogCleanupAuditFn(ctx, input) + } + return nil +} + +func (m *opsRepoMock) InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error) { + return 0, nil +} + +func (m *opsRepoMock) UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error { + return nil +} + +func (m *opsRepoMock) GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error) { + return nil, nil +} + +func (m *opsRepoMock) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error) { + return []*OpsRetryAttempt{}, nil +} + +func (m *opsRepoMock) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error { + return nil +} + +func (m *opsRepoMock) GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error) { + return &OpsWindowStats{}, nil +} + +func (m *opsRepoMock) GetRealtimeTrafficSummary(ctx context.Context, filter *OpsDashboardFilter) (*OpsRealtimeTrafficSummary, error) { + return &OpsRealtimeTrafficSummary{}, nil +} + +func (m *opsRepoMock) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) { + return &OpsDashboardOverview{}, nil +} + +func (m *opsRepoMock) GetThroughputTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsThroughputTrendResponse, error) { + return &OpsThroughputTrendResponse{}, nil +} + +func (m *opsRepoMock) GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error) { + return &OpsLatencyHistogramResponse{}, nil +} + +func (m *opsRepoMock) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error) { + return &OpsErrorTrendResponse{}, nil +} + +func (m *opsRepoMock) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) { + return &OpsErrorDistributionResponse{}, nil +} + +func (m *opsRepoMock) GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) { + return &OpsOpenAITokenStatsResponse{}, nil +} + +func (m *opsRepoMock) InsertSystemMetrics(ctx context.Context, input *OpsInsertSystemMetricsInput) error { + return nil +} + +func (m *opsRepoMock) GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*OpsSystemMetricsSnapshot, error) { + return &OpsSystemMetricsSnapshot{}, nil +} + +func (m *opsRepoMock) UpsertJobHeartbeat(ctx context.Context, input *OpsUpsertJobHeartbeatInput) error { + return nil +} + +func (m *opsRepoMock) ListJobHeartbeats(ctx context.Context) ([]*OpsJobHeartbeat, error) { + return []*OpsJobHeartbeat{}, nil +} + +func (m *opsRepoMock) ListAlertRules(ctx context.Context) ([]*OpsAlertRule, error) { + return []*OpsAlertRule{}, nil +} + +func (m *opsRepoMock) CreateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error) { + return input, nil +} + +func (m *opsRepoMock) UpdateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error) { + return input, nil +} + +func (m *opsRepoMock) DeleteAlertRule(ctx context.Context, id int64) error { + return nil +} + +func (m *opsRepoMock) ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error) { + return []*OpsAlertEvent{}, nil +} + +func (m *opsRepoMock) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) { + return &OpsAlertEvent{}, nil +} + +func (m *opsRepoMock) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) { + return nil, nil +} + +func (m *opsRepoMock) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) { + return nil, nil +} + +func (m *opsRepoMock) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error) { + return event, nil +} + +func (m *opsRepoMock) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error { + return nil +} + +func (m *opsRepoMock) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error { + return nil +} + +func (m *opsRepoMock) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) { + return input, nil +} + +func (m *opsRepoMock) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) { + return false, nil +} + +func (m *opsRepoMock) UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error { + return nil +} + +func (m *opsRepoMock) UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error { + return nil +} + +func (m *opsRepoMock) GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error) { + return time.Time{}, false, nil +} + +func (m *opsRepoMock) GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error) { + return time.Time{}, false, nil +} + +var _ OpsRepository = (*opsRepoMock)(nil) diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index fbc800f2..f0daa3e2 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -12,7 +12,7 @@ import ( "strings" "time" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/domain" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" "github.com/lib/pq" @@ -479,7 +479,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq attemptCtx := ctx if switches > 0 { - attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches) + attemptCtx = WithAccountSwitchCount(attemptCtx, switches, false) } exec := func() *opsRetryExecution { defer selection.ReleaseFunc() @@ -528,7 +528,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) { switch reqType { case opsRetryTypeMessages: - parsed, parseErr := ParseGatewayRequest(body) + parsed, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic) if parseErr != nil { return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr) } @@ -596,7 +596,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq if s.gatewayService == nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"} } - parsedReq, parseErr := ParseGatewayRequest(body) + parsedReq, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic) if parseErr != nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"} } @@ -674,6 +674,7 @@ func newOpsRetryContext(ctx context.Context, errorLog *OpsErrorLogDetail) (*gin. } c.Request = req + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) return c, w } diff --git a/backend/internal/service/ops_retry_context_test.go b/backend/internal/service/ops_retry_context_test.go new file mode 100644 index 00000000..a8c26ee4 --- /dev/null +++ b/backend/internal/service/ops_retry_context_test.go @@ -0,0 +1,47 @@ +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders(t *testing.T) { + errorLog := &OpsErrorLogDetail{ + OpsErrorLog: OpsErrorLog{ + RequestPath: "/openai/v1/responses", + }, + UserAgent: "ops-retry-agent/1.0", + RequestHeaders: `{ + "anthropic-beta":"beta-v1", + "ANTHROPIC-VERSION":"2023-06-01", + "authorization":"Bearer should-not-forward" + }`, + } + + c, w := newOpsRetryContext(context.Background(), errorLog) + require.NotNil(t, c) + require.NotNil(t, w) + require.NotNil(t, c.Request) + + require.Equal(t, "/openai/v1/responses", c.Request.URL.Path) + require.Equal(t, "application/json", c.Request.Header.Get("Content-Type")) + require.Equal(t, "ops-retry-agent/1.0", c.Request.Header.Get("User-Agent")) + require.Equal(t, "beta-v1", c.Request.Header.Get("anthropic-beta")) + require.Equal(t, "2023-06-01", c.Request.Header.Get("anthropic-version")) + require.Empty(t, c.Request.Header.Get("authorization"), "未在白名单内的敏感头不应被重放") + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) +} + +func TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport(t *testing.T) { + errorLog := &OpsErrorLogDetail{ + RequestHeaders: "{invalid-json", + } + + c, _ := newOpsRetryContext(context.Background(), errorLog) + require.NotNil(t, c) + require.NotNil(t, c.Request) + require.Equal(t, "/", c.Request.URL.Path) + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) +} diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go index 9c121b8b..767d1704 100644 --- a/backend/internal/service/ops_service.go +++ b/backend/internal/service/ops_service.go @@ -20,6 +20,22 @@ const ( opsMaxStoredErrorBodyBytes = 20 * 1024 ) +// PrepareOpsRequestBodyForQueue 在入队前对请求体执行脱敏与裁剪,返回可直接写入 OpsInsertErrorLogInput 的字段。 +// 该方法用于避免异步队列持有大块原始请求体,减少错误风暴下的内存放大风险。 +func PrepareOpsRequestBodyForQueue(raw []byte) (requestBodyJSON *string, truncated bool, requestBodyBytes *int) { + if len(raw) == 0 { + return nil, false, nil + } + sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(raw, opsMaxStoredRequestBodyBytes) + if sanitized != "" { + out := sanitized + requestBodyJSON = &out + } + n := bytesLen + requestBodyBytes = &n + return requestBodyJSON, truncated, requestBodyBytes +} + // OpsService provides ingestion and query APIs for the Ops monitoring module. type OpsService struct { opsRepo OpsRepository @@ -37,6 +53,7 @@ type OpsService struct { openAIGatewayService *OpenAIGatewayService geminiCompatService *GeminiMessagesCompatService antigravityGatewayService *AntigravityGatewayService + systemLogSink *OpsSystemLogSink } func NewOpsService( @@ -50,8 +67,9 @@ func NewOpsService( openAIGatewayService *OpenAIGatewayService, geminiCompatService *GeminiMessagesCompatService, antigravityGatewayService *AntigravityGatewayService, + systemLogSink *OpsSystemLogSink, ) *OpsService { - return &OpsService{ + svc := &OpsService{ opsRepo: opsRepo, settingRepo: settingRepo, cfg: cfg, @@ -64,7 +82,10 @@ func NewOpsService( openAIGatewayService: openAIGatewayService, geminiCompatService: geminiCompatService, antigravityGatewayService: antigravityGatewayService, + systemLogSink: systemLogSink, } + svc.applyRuntimeLogConfigOnStartup(context.Background()) + return svc } func (s *OpsService) RequireMonitoringEnabled(ctx context.Context) error { @@ -127,12 +148,7 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn // Sanitize + trim request body (errors only). if len(rawRequestBody) > 0 { - sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(rawRequestBody, opsMaxStoredRequestBodyBytes) - if sanitized != "" { - entry.RequestBodyJSON = &sanitized - } - entry.RequestBodyTruncated = truncated - entry.RequestBodyBytes = &bytesLen + entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = PrepareOpsRequestBodyForQueue(rawRequestBody) } // Sanitize + truncate error_body to avoid storing sensitive data. diff --git a/backend/internal/service/ops_service_prepare_queue_test.go b/backend/internal/service/ops_service_prepare_queue_test.go new file mode 100644 index 00000000..d6f32c2d --- /dev/null +++ b/backend/internal/service/ops_service_prepare_queue_test.go @@ -0,0 +1,60 @@ +package service + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrepareOpsRequestBodyForQueue_EmptyBody(t *testing.T) { + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(nil) + require.Nil(t, requestBodyJSON) + require.False(t, truncated) + require.Nil(t, requestBodyBytes) +} + +func TestPrepareOpsRequestBodyForQueue_InvalidJSON(t *testing.T) { + raw := []byte("{invalid-json") + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.Nil(t, requestBodyJSON) + require.False(t, truncated) + require.NotNil(t, requestBodyBytes) + require.Equal(t, len(raw), *requestBodyBytes) +} + +func TestPrepareOpsRequestBodyForQueue_RedactSensitiveFields(t *testing.T) { + raw := []byte(`{ + "model":"claude-3-5-sonnet-20241022", + "api_key":"sk-test-123", + "headers":{"authorization":"Bearer secret-token"}, + "messages":[{"role":"user","content":"hello"}] + }`) + + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.NotNil(t, requestBodyJSON) + require.NotNil(t, requestBodyBytes) + require.False(t, truncated) + require.Equal(t, len(raw), *requestBodyBytes) + + var body map[string]any + require.NoError(t, json.Unmarshal([]byte(*requestBodyJSON), &body)) + require.Equal(t, "[REDACTED]", body["api_key"]) + headers, ok := body["headers"].(map[string]any) + require.True(t, ok) + require.Equal(t, "[REDACTED]", headers["authorization"]) +} + +func TestPrepareOpsRequestBodyForQueue_LargeBodyTruncated(t *testing.T) { + largeMsg := strings.Repeat("x", opsMaxStoredRequestBodyBytes*2) + raw := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"` + largeMsg + `"}]}`) + + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.NotNil(t, requestBodyJSON) + require.NotNil(t, requestBodyBytes) + require.True(t, truncated) + require.Equal(t, len(raw), *requestBodyBytes) + require.LessOrEqual(t, len(*requestBodyJSON), opsMaxStoredRequestBodyBytes) + require.Contains(t, *requestBodyJSON, "request_body_truncated") +} diff --git a/backend/internal/service/ops_settings_models.go b/backend/internal/service/ops_settings_models.go index ecc62220..8b5359e3 100644 --- a/backend/internal/service/ops_settings_models.go +++ b/backend/internal/service/ops_settings_models.go @@ -68,6 +68,20 @@ type OpsMetricThresholds struct { UpstreamErrorRatePercentMax *float64 `json:"upstream_error_rate_percent_max,omitempty"` // 上游错误率高于此值变红 } +type OpsRuntimeLogConfig struct { + Level string `json:"level"` + EnableSampling bool `json:"enable_sampling"` + SamplingInitial int `json:"sampling_initial"` + SamplingNext int `json:"sampling_thereafter"` + Caller bool `json:"caller"` + StacktraceLevel string `json:"stacktrace_level"` + RetentionDays int `json:"retention_days"` + Source string `json:"source,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` + UpdatedByUserID int64 `json:"updated_by_user_id,omitempty"` + Extra map[string]any `json:"extra,omitempty"` +} + type OpsAlertRuntimeSettings struct { EvaluationIntervalSeconds int `json:"evaluation_interval_seconds"` diff --git a/backend/internal/service/ops_system_log_service.go b/backend/internal/service/ops_system_log_service.go new file mode 100644 index 00000000..f5a64803 --- /dev/null +++ b/backend/internal/service/ops_system_log_service.go @@ -0,0 +1,124 @@ +package service + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "log" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func (s *OpsService) ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + if err := s.RequireMonitoringEnabled(ctx); err != nil { + return nil, err + } + if s.opsRepo == nil { + return &OpsSystemLogList{ + Logs: []*OpsSystemLog{}, + Total: 0, + Page: 1, + PageSize: 50, + }, nil + } + if filter == nil { + filter = &OpsSystemLogFilter{} + } + if filter.Page <= 0 { + filter.Page = 1 + } + if filter.PageSize <= 0 { + filter.PageSize = 50 + } + if filter.PageSize > 200 { + filter.PageSize = 200 + } + + result, err := s.opsRepo.ListSystemLogs(ctx, filter) + if err != nil { + return nil, infraerrors.InternalServer("OPS_SYSTEM_LOG_LIST_FAILED", "Failed to list system logs").WithCause(err) + } + return result, nil +} + +func (s *OpsService) CleanupSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter, operatorID int64) (int64, error) { + if err := s.RequireMonitoringEnabled(ctx); err != nil { + return 0, err + } + if s.opsRepo == nil { + return 0, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available") + } + if operatorID <= 0 { + return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_INVALID_OPERATOR", "invalid operator") + } + if filter == nil { + filter = &OpsSystemLogCleanupFilter{} + } + if filter.EndTime != nil && filter.StartTime != nil && filter.StartTime.After(*filter.EndTime) { + return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_INVALID_RANGE", "invalid time range") + } + + deletedRows, err := s.opsRepo.DeleteSystemLogs(ctx, filter) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + if strings.Contains(strings.ToLower(err.Error()), "requires at least one filter") { + return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_FILTER_REQUIRED", "cleanup requires at least one filter condition") + } + return 0, infraerrors.InternalServer("OPS_SYSTEM_LOG_CLEANUP_FAILED", "Failed to cleanup system logs").WithCause(err) + } + + if auditErr := s.opsRepo.InsertSystemLogCleanupAudit(ctx, &OpsSystemLogCleanupAudit{ + CreatedAt: time.Now().UTC(), + OperatorID: operatorID, + Conditions: marshalSystemLogCleanupConditions(filter), + DeletedRows: deletedRows, + }); auditErr != nil { + // 审计失败不影响主流程,避免运维清理被阻塞。 + log.Printf("[OpsSystemLog] cleanup audit failed: %v", auditErr) + } + return deletedRows, nil +} + +func marshalSystemLogCleanupConditions(filter *OpsSystemLogCleanupFilter) string { + if filter == nil { + return "{}" + } + payload := map[string]any{ + "level": strings.TrimSpace(filter.Level), + "component": strings.TrimSpace(filter.Component), + "request_id": strings.TrimSpace(filter.RequestID), + "client_request_id": strings.TrimSpace(filter.ClientRequestID), + "platform": strings.TrimSpace(filter.Platform), + "model": strings.TrimSpace(filter.Model), + "query": strings.TrimSpace(filter.Query), + } + if filter.UserID != nil { + payload["user_id"] = *filter.UserID + } + if filter.AccountID != nil { + payload["account_id"] = *filter.AccountID + } + if filter.StartTime != nil && !filter.StartTime.IsZero() { + payload["start_time"] = filter.StartTime.UTC().Format(time.RFC3339Nano) + } + if filter.EndTime != nil && !filter.EndTime.IsZero() { + payload["end_time"] = filter.EndTime.UTC().Format(time.RFC3339Nano) + } + raw, err := json.Marshal(payload) + if err != nil { + return "{}" + } + return string(raw) +} + +func (s *OpsService) GetSystemLogSinkHealth() OpsSystemLogSinkHealth { + if s == nil || s.systemLogSink == nil { + return OpsSystemLogSinkHealth{} + } + return s.systemLogSink.Health() +} diff --git a/backend/internal/service/ops_system_log_service_test.go b/backend/internal/service/ops_system_log_service_test.go new file mode 100644 index 00000000..cc9ddefe --- /dev/null +++ b/backend/internal/service/ops_system_log_service_test.go @@ -0,0 +1,243 @@ +package service + +import ( + "context" + "database/sql" + "errors" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +func TestOpsServiceListSystemLogs_DefaultClampAndSuccess(t *testing.T) { + var gotFilter *OpsSystemLogFilter + repo := &opsRepoMock{ + ListSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + gotFilter = filter + return &OpsSystemLogList{ + Logs: []*OpsSystemLog{{ID: 1, Level: "warn", Message: "x"}}, + Total: 1, + Page: filter.Page, + PageSize: filter.PageSize, + }, nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + + out, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{ + Page: 0, + PageSize: 999, + }) + if err != nil { + t.Fatalf("ListSystemLogs() error: %v", err) + } + if gotFilter == nil { + t.Fatalf("expected repository to receive filter") + } + if gotFilter.Page != 1 || gotFilter.PageSize != 200 { + t.Fatalf("filter normalized unexpectedly: page=%d pageSize=%d", gotFilter.Page, gotFilter.PageSize) + } + if out.Total != 1 || len(out.Logs) != 1 { + t.Fatalf("unexpected result: %+v", out) + } +} + +func TestOpsServiceListSystemLogs_MonitoringDisabled(t *testing.T) { + svc := NewOpsService( + &opsRepoMock{}, + nil, + &config.Config{Ops: config.OpsConfig{Enabled: false}}, + nil, nil, nil, nil, nil, nil, nil, nil, + ) + _, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{}) + if err == nil { + t.Fatalf("expected disabled error") + } +} + +func TestOpsServiceListSystemLogs_NilRepoReturnsEmpty(t *testing.T) { + svc := NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + out, err := svc.ListSystemLogs(context.Background(), nil) + if err != nil { + t.Fatalf("ListSystemLogs() error: %v", err) + } + if out == nil || out.Page != 1 || out.PageSize != 50 || out.Total != 0 || len(out.Logs) != 0 { + t.Fatalf("unexpected nil-repo result: %+v", out) + } +} + +func TestOpsServiceListSystemLogs_RepoErrorMapped(t *testing.T) { + repo := &opsRepoMock{ + ListSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + return nil, errors.New("db down") + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + _, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{}) + if err == nil { + t.Fatalf("expected mapped internal error") + } + if !strings.Contains(err.Error(), "OPS_SYSTEM_LOG_LIST_FAILED") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestOpsServiceCleanupSystemLogs_SuccessAndAudit(t *testing.T) { + var audit *OpsSystemLogCleanupAudit + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 3, nil + }, + InsertSystemLogCleanupAuditFn: func(ctx context.Context, input *OpsSystemLogCleanupAudit) error { + audit = input + return nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + userID := int64(7) + now := time.Now().UTC() + filter := &OpsSystemLogCleanupFilter{ + StartTime: &now, + Level: "warn", + RequestID: "req-1", + ClientRequestID: "creq-1", + UserID: &userID, + Query: "timeout", + } + + deleted, err := svc.CleanupSystemLogs(context.Background(), filter, 99) + if err != nil { + t.Fatalf("CleanupSystemLogs() error: %v", err) + } + if deleted != 3 { + t.Fatalf("deleted=%d, want 3", deleted) + } + if audit == nil { + t.Fatalf("expected cleanup audit") + } + if !strings.Contains(audit.Conditions, `"client_request_id":"creq-1"`) { + t.Fatalf("audit conditions should include client_request_id: %s", audit.Conditions) + } + if !strings.Contains(audit.Conditions, `"user_id":7`) { + t.Fatalf("audit conditions should include user_id: %s", audit.Conditions) + } +} + +func TestOpsServiceCleanupSystemLogs_RepoUnavailableAndInvalidOperator(t *testing.T) { + svc := NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{RequestID: "r"}, 1); err == nil { + t.Fatalf("expected repo unavailable error") + } + + svc = NewOpsService(&opsRepoMock{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{RequestID: "r"}, 0); err == nil { + t.Fatalf("expected invalid operator error") + } +} + +func TestOpsServiceCleanupSystemLogs_FilterRequired(t *testing.T) { + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 0, errors.New("cleanup requires at least one filter condition") + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{}, 1) + if err == nil { + t.Fatalf("expected filter required error") + } + if !strings.Contains(strings.ToLower(err.Error()), "filter") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestOpsServiceCleanupSystemLogs_InvalidRange(t *testing.T) { + repo := &opsRepoMock{} + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + start := time.Now().UTC() + end := start.Add(-time.Hour) + _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + StartTime: &start, + EndTime: &end, + }, 1) + if err == nil { + t.Fatalf("expected invalid range error") + } +} + +func TestOpsServiceCleanupSystemLogs_NoRowsAndInternalError(t *testing.T) { + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 0, sql.ErrNoRows + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + deleted, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + RequestID: "req-1", + }, 1) + if err != nil || deleted != 0 { + t.Fatalf("expected no rows shortcut, deleted=%d err=%v", deleted, err) + } + + repo.DeleteSystemLogsFn = func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 0, errors.New("boom") + } + if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + RequestID: "req-1", + }, 1); err == nil { + t.Fatalf("expected internal cleanup error") + } +} + +func TestOpsServiceCleanupSystemLogs_AuditFailureIgnored(t *testing.T) { + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 5, nil + }, + InsertSystemLogCleanupAuditFn: func(ctx context.Context, input *OpsSystemLogCleanupAudit) error { + return errors.New("audit down") + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + deleted, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + RequestID: "r1", + }, 1) + if err != nil || deleted != 5 { + t.Fatalf("audit failure should not break cleanup, deleted=%d err=%v", deleted, err) + } +} + +func TestMarshalSystemLogCleanupConditions_NilAndMarshalError(t *testing.T) { + if got := marshalSystemLogCleanupConditions(nil); got != "{}" { + t.Fatalf("nil filter should return {}, got %s", got) + } + + now := time.Now().UTC() + userID := int64(1) + filter := &OpsSystemLogCleanupFilter{ + StartTime: &now, + EndTime: &now, + UserID: &userID, + } + got := marshalSystemLogCleanupConditions(filter) + if !strings.Contains(got, `"start_time"`) || !strings.Contains(got, `"user_id":1`) { + t.Fatalf("unexpected marshal payload: %s", got) + } +} + +func TestOpsServiceGetSystemLogSinkHealth(t *testing.T) { + svc := NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + health := svc.GetSystemLogSinkHealth() + if health.QueueCapacity != 0 || health.QueueDepth != 0 { + t.Fatalf("unexpected health for nil sink: %+v", health) + } + + sink := NewOpsSystemLogSink(&opsRepoMock{}) + svc = NewOpsService(&opsRepoMock{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink) + health = svc.GetSystemLogSinkHealth() + if health.QueueCapacity <= 0 { + t.Fatalf("expected non-zero queue capacity: %+v", health) + } +} diff --git a/backend/internal/service/ops_system_log_sink.go b/backend/internal/service/ops_system_log_sink.go new file mode 100644 index 00000000..c50a30d5 --- /dev/null +++ b/backend/internal/service/ops_system_log_sink.go @@ -0,0 +1,335 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" +) + +type OpsSystemLogSinkHealth struct { + QueueDepth int64 `json:"queue_depth"` + QueueCapacity int64 `json:"queue_capacity"` + DroppedCount uint64 `json:"dropped_count"` + WriteFailed uint64 `json:"write_failed_count"` + WrittenCount uint64 `json:"written_count"` + AvgWriteDelayMs uint64 `json:"avg_write_delay_ms"` + LastError string `json:"last_error"` +} + +type OpsSystemLogSink struct { + opsRepo OpsRepository + + queue chan *logger.LogEvent + + batchSize int + flushInterval time.Duration + + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + droppedCount uint64 + writeFailed uint64 + writtenCount uint64 + totalDelayNs uint64 + + lastError atomic.Value +} + +func NewOpsSystemLogSink(opsRepo OpsRepository) *OpsSystemLogSink { + ctx, cancel := context.WithCancel(context.Background()) + s := &OpsSystemLogSink{ + opsRepo: opsRepo, + queue: make(chan *logger.LogEvent, 5000), + batchSize: 200, + flushInterval: time.Second, + ctx: ctx, + cancel: cancel, + } + s.lastError.Store("") + return s +} + +func (s *OpsSystemLogSink) Start() { + if s == nil || s.opsRepo == nil { + return + } + s.wg.Add(1) + go s.run() +} + +func (s *OpsSystemLogSink) Stop() { + if s == nil { + return + } + s.cancel() + s.wg.Wait() +} + +func (s *OpsSystemLogSink) WriteLogEvent(event *logger.LogEvent) { + if s == nil || event == nil || !s.shouldIndex(event) { + return + } + if s.ctx != nil { + select { + case <-s.ctx.Done(): + return + default: + } + } + + select { + case s.queue <- event: + default: + atomic.AddUint64(&s.droppedCount, 1) + } +} + +func (s *OpsSystemLogSink) shouldIndex(event *logger.LogEvent) bool { + level := strings.ToLower(strings.TrimSpace(event.Level)) + switch level { + case "warn", "warning", "error", "fatal", "panic", "dpanic": + return true + } + + component := strings.ToLower(strings.TrimSpace(event.Component)) + // zap 的 LoggerName 往往为空或不等于业务组件名;业务组件名通常以字段 component 透传。 + if event.Fields != nil { + if fc := strings.ToLower(strings.TrimSpace(asString(event.Fields["component"]))); fc != "" { + component = fc + } + } + if strings.Contains(component, "http.access") { + return true + } + if strings.Contains(component, "audit") { + return true + } + return false +} + +func (s *OpsSystemLogSink) run() { + defer s.wg.Done() + + ticker := time.NewTicker(s.flushInterval) + defer ticker.Stop() + + batch := make([]*logger.LogEvent, 0, s.batchSize) + flush := func(baseCtx context.Context) { + if len(batch) == 0 { + return + } + started := time.Now() + inserted, err := s.flushBatch(baseCtx, batch) + delay := time.Since(started) + if err != nil { + atomic.AddUint64(&s.writeFailed, uint64(len(batch))) + s.lastError.Store(err.Error()) + _, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"ops system log sink flush failed\" err=%v batch=%d\n", + time.Now().Format(time.RFC3339Nano), err, len(batch), + ) + } else { + atomic.AddUint64(&s.writtenCount, uint64(inserted)) + atomic.AddUint64(&s.totalDelayNs, uint64(delay.Nanoseconds())) + s.lastError.Store("") + } + batch = batch[:0] + } + drainAndFlush := func() { + for { + select { + case item := <-s.queue: + if item == nil { + continue + } + batch = append(batch, item) + if len(batch) >= s.batchSize { + flush(context.Background()) + } + default: + flush(context.Background()) + return + } + } + } + + for { + select { + case <-s.ctx.Done(): + drainAndFlush() + return + case item := <-s.queue: + if item == nil { + continue + } + batch = append(batch, item) + if len(batch) >= s.batchSize { + flush(s.ctx) + } + case <-ticker.C: + flush(s.ctx) + } + } +} + +func (s *OpsSystemLogSink) flushBatch(baseCtx context.Context, batch []*logger.LogEvent) (int, error) { + inputs := make([]*OpsInsertSystemLogInput, 0, len(batch)) + for _, event := range batch { + if event == nil { + continue + } + createdAt := event.Time.UTC() + if createdAt.IsZero() { + createdAt = time.Now().UTC() + } + + fields := copyMap(event.Fields) + requestID := asString(fields["request_id"]) + clientRequestID := asString(fields["client_request_id"]) + platform := asString(fields["platform"]) + model := asString(fields["model"]) + component := strings.TrimSpace(event.Component) + if fieldComponent := asString(fields["component"]); fieldComponent != "" { + component = fieldComponent + } + if component == "" { + component = "app" + } + + userID := asInt64Ptr(fields["user_id"]) + accountID := asInt64Ptr(fields["account_id"]) + + // 统一脱敏后写入索引。 + message := logredact.RedactText(strings.TrimSpace(event.Message)) + redactedExtra := logredact.RedactMap(fields) + extraJSONBytes, _ := json.Marshal(redactedExtra) + extraJSON := string(extraJSONBytes) + if strings.TrimSpace(extraJSON) == "" { + extraJSON = "{}" + } + + inputs = append(inputs, &OpsInsertSystemLogInput{ + CreatedAt: createdAt, + Level: strings.ToLower(strings.TrimSpace(event.Level)), + Component: component, + Message: message, + RequestID: requestID, + ClientRequestID: clientRequestID, + UserID: userID, + AccountID: accountID, + Platform: platform, + Model: model, + ExtraJSON: extraJSON, + }) + } + + if len(inputs) == 0 { + return 0, nil + } + if baseCtx == nil || baseCtx.Err() != nil { + baseCtx = context.Background() + } + ctx, cancel := context.WithTimeout(baseCtx, 5*time.Second) + defer cancel() + inserted, err := s.opsRepo.BatchInsertSystemLogs(ctx, inputs) + if err != nil { + return 0, err + } + return int(inserted), nil +} + +func (s *OpsSystemLogSink) Health() OpsSystemLogSinkHealth { + if s == nil { + return OpsSystemLogSinkHealth{} + } + written := atomic.LoadUint64(&s.writtenCount) + totalDelay := atomic.LoadUint64(&s.totalDelayNs) + var avgDelay uint64 + if written > 0 { + avgDelay = (totalDelay / written) / uint64(time.Millisecond) + } + + lastErr, _ := s.lastError.Load().(string) + return OpsSystemLogSinkHealth{ + QueueDepth: int64(len(s.queue)), + QueueCapacity: int64(cap(s.queue)), + DroppedCount: atomic.LoadUint64(&s.droppedCount), + WriteFailed: atomic.LoadUint64(&s.writeFailed), + WrittenCount: written, + AvgWriteDelayMs: avgDelay, + LastError: strings.TrimSpace(lastErr), + } +} + +func copyMap(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func asString(v any) string { + switch t := v.(type) { + case string: + return strings.TrimSpace(t) + case fmt.Stringer: + return strings.TrimSpace(t.String()) + default: + return "" + } +} + +func asInt64Ptr(v any) *int64 { + switch t := v.(type) { + case int: + n := int64(t) + if n <= 0 { + return nil + } + return &n + case int64: + n := t + if n <= 0 { + return nil + } + return &n + case float64: + n := int64(t) + if n <= 0 { + return nil + } + return &n + case json.Number: + if n, err := t.Int64(); err == nil { + if n <= 0 { + return nil + } + return &n + } + case string: + raw := strings.TrimSpace(t) + if raw == "" { + return nil + } + if n, err := strconv.ParseInt(raw, 10, 64); err == nil { + if n <= 0 { + return nil + } + return &n + } + } + return nil +} diff --git a/backend/internal/service/ops_system_log_sink_test.go b/backend/internal/service/ops_system_log_sink_test.go new file mode 100644 index 00000000..12a2ec0c --- /dev/null +++ b/backend/internal/service/ops_system_log_sink_test.go @@ -0,0 +1,313 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +func TestOpsSystemLogSink_ShouldIndex(t *testing.T) { + sink := &OpsSystemLogSink{} + + cases := []struct { + name string + event *logger.LogEvent + want bool + }{ + { + name: "warn level", + event: &logger.LogEvent{Level: "warn", Component: "app"}, + want: true, + }, + { + name: "error level", + event: &logger.LogEvent{Level: "error", Component: "app"}, + want: true, + }, + { + name: "access component", + event: &logger.LogEvent{Level: "info", Component: "http.access"}, + want: true, + }, + { + name: "access component from fields (real zap path)", + event: &logger.LogEvent{ + Level: "info", + Component: "", + Fields: map[string]any{"component": "http.access"}, + }, + want: true, + }, + { + name: "audit component", + event: &logger.LogEvent{Level: "info", Component: "audit.log_config_change"}, + want: true, + }, + { + name: "audit component from fields (real zap path)", + event: &logger.LogEvent{ + Level: "info", + Component: "", + Fields: map[string]any{"component": "audit.log_config_change"}, + }, + want: true, + }, + { + name: "plain info", + event: &logger.LogEvent{Level: "info", Component: "app"}, + want: false, + }, + } + + for _, tc := range cases { + if got := sink.shouldIndex(tc.event); got != tc.want { + t.Fatalf("%s: shouldIndex()=%v, want %v", tc.name, got, tc.want) + } + } +} + +func TestOpsSystemLogSink_WriteLogEvent_ShouldDropWhenQueueFull(t *testing.T) { + sink := &OpsSystemLogSink{ + queue: make(chan *logger.LogEvent, 1), + } + + sink.WriteLogEvent(&logger.LogEvent{Level: "warn", Component: "app"}) + sink.WriteLogEvent(&logger.LogEvent{Level: "warn", Component: "app"}) + + if got := len(sink.queue); got != 1 { + t.Fatalf("queue len = %d, want 1", got) + } + if dropped := atomic.LoadUint64(&sink.droppedCount); dropped != 1 { + t.Fatalf("droppedCount = %d, want 1", dropped) + } +} + +func TestOpsSystemLogSink_Health(t *testing.T) { + sink := &OpsSystemLogSink{ + queue: make(chan *logger.LogEvent, 10), + } + sink.lastError.Store("db timeout") + atomic.StoreUint64(&sink.droppedCount, 3) + atomic.StoreUint64(&sink.writeFailed, 2) + atomic.StoreUint64(&sink.writtenCount, 5) + atomic.StoreUint64(&sink.totalDelayNs, uint64(5000000)) // 5ms total -> avg 1ms + sink.queue <- &logger.LogEvent{Level: "warn", Component: "app"} + sink.queue <- &logger.LogEvent{Level: "warn", Component: "app"} + + health := sink.Health() + if health.QueueDepth != 2 { + t.Fatalf("queue depth = %d, want 2", health.QueueDepth) + } + if health.QueueCapacity != 10 { + t.Fatalf("queue capacity = %d, want 10", health.QueueCapacity) + } + if health.DroppedCount != 3 { + t.Fatalf("dropped = %d, want 3", health.DroppedCount) + } + if health.WriteFailed != 2 { + t.Fatalf("write failed = %d, want 2", health.WriteFailed) + } + if health.WrittenCount != 5 { + t.Fatalf("written = %d, want 5", health.WrittenCount) + } + if health.AvgWriteDelayMs != 1 { + t.Fatalf("avg delay ms = %d, want 1", health.AvgWriteDelayMs) + } + if health.LastError != "db timeout" { + t.Fatalf("last error = %q, want db timeout", health.LastError) + } +} + +func TestOpsSystemLogSink_StartStopAndFlushSuccess(t *testing.T) { + done := make(chan struct{}, 1) + var captured []*OpsInsertSystemLogInput + repo := &opsRepoMock{ + BatchInsertSystemLogsFn: func(_ context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + captured = append(captured, inputs...) + select { + case done <- struct{}{}: + default: + } + return int64(len(inputs)), nil + }, + } + + sink := NewOpsSystemLogSink(repo) + sink.batchSize = 1 + sink.flushInterval = 10 * time.Millisecond + sink.Start() + defer sink.Stop() + + sink.WriteLogEvent(&logger.LogEvent{ + Time: time.Now().UTC(), + Level: "warn", + Component: "http.access", + Message: `authorization="Bearer sk-test-123"`, + Fields: map[string]any{ + "component": "http.access", + "request_id": "req-1", + "client_request_id": "creq-1", + "user_id": "12", + "account_id": json.Number("34"), + "platform": "openai", + "model": "gpt-5", + }, + }) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for sink flush") + } + + if len(captured) != 1 { + t.Fatalf("captured len = %d, want 1", len(captured)) + } + item := captured[0] + if item.RequestID != "req-1" || item.ClientRequestID != "creq-1" { + t.Fatalf("unexpected request ids: %+v", item) + } + if item.UserID == nil || *item.UserID != 12 { + t.Fatalf("unexpected user_id: %+v", item.UserID) + } + if item.AccountID == nil || *item.AccountID != 34 { + t.Fatalf("unexpected account_id: %+v", item.AccountID) + } + if strings.TrimSpace(item.Message) == "" { + t.Fatalf("message should not be empty") + } + health := sink.Health() + if health.WrittenCount == 0 { + t.Fatalf("written_count should be >0") + } +} + +func TestOpsSystemLogSink_FlushFailureUpdatesHealth(t *testing.T) { + repo := &opsRepoMock{ + BatchInsertSystemLogsFn: func(_ context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + return 0, errors.New("db unavailable") + }, + } + sink := NewOpsSystemLogSink(repo) + sink.batchSize = 1 + sink.flushInterval = 10 * time.Millisecond + sink.Start() + defer sink.Stop() + + sink.WriteLogEvent(&logger.LogEvent{ + Time: time.Now().UTC(), + Level: "warn", + Component: "app", + Message: "boom", + Fields: map[string]any{}, + }) + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + health := sink.Health() + if health.WriteFailed > 0 { + if !strings.Contains(health.LastError, "db unavailable") { + t.Fatalf("unexpected last error: %s", health.LastError) + } + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("write_failed_count not updated") +} + +func TestOpsSystemLogSink_StopFlushUsesActiveContextAndDrainsQueue(t *testing.T) { + var inserted int64 + var canceledCtxCalls int64 + repo := &opsRepoMock{ + BatchInsertSystemLogsFn: func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + if err := ctx.Err(); err != nil { + atomic.AddInt64(&canceledCtxCalls, 1) + return 0, err + } + atomic.AddInt64(&inserted, int64(len(inputs))) + return int64(len(inputs)), nil + }, + } + + sink := NewOpsSystemLogSink(repo) + sink.batchSize = 200 + sink.flushInterval = time.Hour + sink.Start() + + sink.WriteLogEvent(&logger.LogEvent{ + Time: time.Now().UTC(), + Level: "warn", + Component: "app", + Message: "pending-on-shutdown", + Fields: map[string]any{"component": "http.access"}, + }) + + sink.Stop() + + if got := atomic.LoadInt64(&inserted); got != 1 { + t.Fatalf("inserted = %d, want 1", got) + } + if got := atomic.LoadInt64(&canceledCtxCalls); got != 0 { + t.Fatalf("canceled ctx calls = %d, want 0", got) + } + health := sink.Health() + if health.WrittenCount != 1 { + t.Fatalf("written_count = %d, want 1", health.WrittenCount) + } +} + +type stringerValue string + +func (s stringerValue) String() string { return string(s) } + +func TestOpsSystemLogSink_HelperFunctions(t *testing.T) { + src := map[string]any{"a": 1} + cloned := copyMap(src) + src["a"] = 2 + v, ok := cloned["a"].(int) + if !ok || v != 1 { + t.Fatalf("copyMap should create copy") + } + if got := asString(stringerValue(" hello ")); got != "hello" { + t.Fatalf("asString stringer = %q", got) + } + if got := asString(fmt.Errorf("x")); got != "" { + t.Fatalf("asString error should be empty, got %q", got) + } + if got := asString(123); got != "" { + t.Fatalf("asString non-string should be empty, got %q", got) + } + + cases := []struct { + in any + want int64 + ok bool + }{ + {in: 5, want: 5, ok: true}, + {in: int64(6), want: 6, ok: true}, + {in: float64(7), want: 7, ok: true}, + {in: json.Number("8"), want: 8, ok: true}, + {in: "9", want: 9, ok: true}, + {in: "0", ok: false}, + {in: -1, ok: false}, + {in: "abc", ok: false}, + } + for _, tc := range cases { + got := asInt64Ptr(tc.in) + if tc.ok { + if got == nil || *got != tc.want { + t.Fatalf("asInt64Ptr(%v) = %+v, want %d", tc.in, got, tc.want) + } + } else if got != nil { + t.Fatalf("asInt64Ptr(%v) should be nil, got %d", tc.in, *got) + } + } +} diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index 96bcc9fe..21e09c43 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -20,8 +20,39 @@ const ( // retry the specific upstream attempt (not just the client request). // This value is sanitized+trimmed before being persisted. OpsUpstreamRequestBodyKey = "ops_upstream_request_body" + + // Optional stage latencies (milliseconds) for troubleshooting and alerting. + OpsAuthLatencyMsKey = "ops_auth_latency_ms" + OpsRoutingLatencyMsKey = "ops_routing_latency_ms" + OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms" + OpsResponseLatencyMsKey = "ops_response_latency_ms" + OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms" + // OpenAI WS 关键观测字段 + OpsOpenAIWSQueueWaitMsKey = "ops_openai_ws_queue_wait_ms" + OpsOpenAIWSConnPickMsKey = "ops_openai_ws_conn_pick_ms" + OpsOpenAIWSConnReusedKey = "ops_openai_ws_conn_reused" + OpsOpenAIWSConnIDKey = "ops_openai_ws_conn_id" + + // OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。 + // ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。 + OpsSkipPassthroughKey = "ops_skip_passthrough" ) +func setOpsUpstreamRequestBody(c *gin.Context, body []byte) { + if c == nil || len(body) == 0 { + return + } + // 热路径避免 string(body) 额外分配,按需在落库前再转换。 + c.Set(OpsUpstreamRequestBodyKey, body) +} + +func SetOpsLatencyMs(c *gin.Context, key string, value int64) { + if c == nil || strings.TrimSpace(key) == "" || value < 0 { + return + } + c.Set(key, value) +} + func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) { if c == nil { return @@ -42,6 +73,10 @@ func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage type OpsUpstreamErrorEvent struct { AtUnixMs int64 `json:"at_unix_ms,omitempty"` + // Passthrough 表示本次请求是否命中“原样透传(仅替换认证)”分支。 + // 该字段用于排障与灰度评估;存入 JSON,不涉及 DB schema 变更。 + Passthrough bool `json:"passthrough,omitempty"` + // Context Platform string `json:"platform,omitempty"` AccountID int64 `json:"account_id,omitempty"` @@ -87,8 +122,11 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { // stored it on the context, attach it so ops can retry this specific attempt. if ev.UpstreamRequestBody == "" { if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok { - if s, ok := v.(string); ok { - ev.UpstreamRequestBody = strings.TrimSpace(s) + switch raw := v.(type) { + case string: + ev.UpstreamRequestBody = strings.TrimSpace(raw) + case []byte: + ev.UpstreamRequestBody = strings.TrimSpace(string(raw)) } } } @@ -103,6 +141,37 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { evCopy := ev existing = append(existing, &evCopy) c.Set(OpsUpstreamErrorsKey, existing) + + checkSkipMonitoringForUpstreamEvent(c, &evCopy) +} + +// checkSkipMonitoringForUpstreamEvent checks whether the upstream error event +// matches a passthrough rule with skip_monitoring=true and, if so, sets the +// OpsSkipPassthroughKey on the context. This ensures intermediate retry / +// failover errors (which never go through the final applyErrorPassthroughRule +// path) can still suppress ops_error_logs recording. +func checkSkipMonitoringForUpstreamEvent(c *gin.Context, ev *OpsUpstreamErrorEvent) { + if ev.UpstreamStatusCode == 0 { + return + } + + svc := getBoundErrorPassthroughService(c) + if svc == nil { + return + } + + // Use the best available body representation for keyword matching. + // Even when body is empty, MatchRule can still match rules that only + // specify ErrorCodes (no Keywords), so we always call it. + body := ev.Detail + if body == "" { + body = ev.Message + } + + rule := svc.MatchRule(ev.Platform, ev.UpstreamStatusCode, []byte(body)) + if rule != nil && rule.SkipMonitoring { + c.Set(OpsSkipPassthroughKey, true) + } } func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string { diff --git a/backend/internal/service/ops_upstream_context_test.go b/backend/internal/service/ops_upstream_context_test.go new file mode 100644 index 00000000..50ceaa0e --- /dev/null +++ b/backend/internal/service/ops_upstream_context_test.go @@ -0,0 +1,47 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + setOpsUpstreamRequestBody(c, []byte(`{"model":"gpt-5"}`)) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Kind: "http_error", + Message: "upstream failed", + }) + + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, `{"model":"gpt-5"}`, events[0].UpstreamRequestBody) +} + +func TestAppendOpsUpstreamError_UsesRequestBodyStringFromContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + c.Set(OpsUpstreamRequestBodyKey, `{"model":"gpt-4"}`) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Kind: "request_error", + Message: "dial timeout", + }) + + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, `{"model":"gpt-4"}`, events[0].UpstreamRequestBody) +} diff --git a/backend/internal/service/parse_integral_number_unit.go b/backend/internal/service/parse_integral_number_unit.go new file mode 100644 index 00000000..c9c617b1 --- /dev/null +++ b/backend/internal/service/parse_integral_number_unit.go @@ -0,0 +1,51 @@ +//go:build unit + +package service + +import ( + "encoding/json" + "math" +) + +// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。 +// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。 +// +// 说明: +// - 该函数当前仅用于 unit 测试中的 map-based 解析逻辑验证,因此放在 unit build tag 下, +// 避免在默认构建中触发 unused lint。 +func parseIntegralNumber(raw any) (int, bool) { + switch v := raw.(type) { + case float64: + if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) { + return 0, false + } + if v > float64(math.MaxInt) || v < float64(math.MinInt) { + return 0, false + } + return int(v), true + case int: + return v, true + case int8: + return int(v), true + case int16: + return int(v), true + case int32: + return int(v), true + case int64: + if v > int64(math.MaxInt) || v < int64(math.MinInt) { + return 0, false + } + return int(v), true + case json.Number: + i64, err := v.Int64() + if err != nil { + return 0, false + } + if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) { + return 0, false + } + return int(i64), true + default: + return 0, false + } +} diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index d8db0d67..41e8b5eb 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -6,7 +6,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "log" "os" "path/filepath" "regexp" @@ -15,8 +14,10 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "go.uber.org/zap" ) var ( @@ -27,14 +28,15 @@ var ( // LiteLLMModelPricing LiteLLM价格数据结构 // 只保留我们需要的字段,使用指针来处理可能缺失的值 type LiteLLMModelPricing struct { - InputCostPerToken float64 `json:"input_cost_per_token"` - OutputCostPerToken float64 `json:"output_cost_per_token"` - CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` - CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` - LiteLLMProvider string `json:"litellm_provider"` - Mode string `json:"mode"` - SupportsPromptCaching bool `json:"supports_prompt_caching"` - OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 + InputCostPerToken float64 `json:"input_cost_per_token"` + OutputCostPerToken float64 `json:"output_cost_per_token"` + CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` + CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"` + CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` + LiteLLMProvider string `json:"litellm_provider"` + Mode string `json:"mode"` + SupportsPromptCaching bool `json:"supports_prompt_caching"` + OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 } // PricingRemoteClient 远程价格数据获取接口 @@ -45,14 +47,15 @@ type PricingRemoteClient interface { // LiteLLMRawEntry 用于解析原始JSON数据 type LiteLLMRawEntry struct { - InputCostPerToken *float64 `json:"input_cost_per_token"` - OutputCostPerToken *float64 `json:"output_cost_per_token"` - CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` - CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` - LiteLLMProvider string `json:"litellm_provider"` - Mode string `json:"mode"` - SupportsPromptCaching bool `json:"supports_prompt_caching"` - OutputCostPerImage *float64 `json:"output_cost_per_image"` + InputCostPerToken *float64 `json:"input_cost_per_token"` + OutputCostPerToken *float64 `json:"output_cost_per_token"` + CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` + CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"` + CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` + LiteLLMProvider string `json:"litellm_provider"` + Mode string `json:"mode"` + SupportsPromptCaching bool `json:"supports_prompt_caching"` + OutputCostPerImage *float64 `json:"output_cost_per_image"` } // PricingService 动态价格服务 @@ -84,12 +87,12 @@ func NewPricingService(cfg *config.Config, remoteClient PricingRemoteClient) *Pr func (s *PricingService) Initialize() error { // 确保数据目录存在 if err := os.MkdirAll(s.cfg.Pricing.DataDir, 0755); err != nil { - log.Printf("[Pricing] Failed to create data directory: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to create data directory: %v", err) } // 首次加载价格数据 if err := s.checkAndUpdatePricing(); err != nil { - log.Printf("[Pricing] Initial load failed, using fallback: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Initial load failed, using fallback: %v", err) if err := s.useFallbackPricing(); err != nil { return fmt.Errorf("failed to load pricing data: %w", err) } @@ -98,7 +101,7 @@ func (s *PricingService) Initialize() error { // 启动定时更新 s.startUpdateScheduler() - log.Printf("[Pricing] Service initialized with %d models", len(s.pricingData)) + logger.LegacyPrintf("service.pricing", "[Pricing] Service initialized with %d models", len(s.pricingData)) return nil } @@ -106,7 +109,7 @@ func (s *PricingService) Initialize() error { func (s *PricingService) Stop() { close(s.stopCh) s.wg.Wait() - log.Println("[Pricing] Service stopped") + logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Service stopped") } // startUpdateScheduler 启动定时更新调度器 @@ -127,7 +130,7 @@ func (s *PricingService) startUpdateScheduler() { select { case <-ticker.C: if err := s.syncWithRemote(); err != nil { - log.Printf("[Pricing] Sync failed: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Sync failed: %v", err) } case <-s.stopCh: return @@ -135,7 +138,7 @@ func (s *PricingService) startUpdateScheduler() { } }() - log.Printf("[Pricing] Update scheduler started (check every %v)", hashInterval) + logger.LegacyPrintf("service.pricing", "[Pricing] Update scheduler started (check every %v)", hashInterval) } // checkAndUpdatePricing 检查并更新价格数据 @@ -144,7 +147,7 @@ func (s *PricingService) checkAndUpdatePricing() error { // 检查本地文件是否存在 if _, err := os.Stat(pricingFile); os.IsNotExist(err) { - log.Println("[Pricing] Local pricing file not found, downloading...") + logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Local pricing file not found, downloading...") return s.downloadPricingData() } @@ -158,9 +161,9 @@ func (s *PricingService) checkAndUpdatePricing() error { maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour if fileAge > maxAge { - log.Printf("[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour)) + logger.LegacyPrintf("service.pricing", "[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour)) if err := s.downloadPricingData(); err != nil { - log.Printf("[Pricing] Download failed, using existing file: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Download failed, using existing file: %v", err) } } @@ -175,7 +178,7 @@ func (s *PricingService) syncWithRemote() error { // 计算本地文件哈希 localHash, err := s.computeFileHash(pricingFile) if err != nil { - log.Printf("[Pricing] Failed to compute local hash: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to compute local hash: %v", err) return s.downloadPricingData() } @@ -183,15 +186,15 @@ func (s *PricingService) syncWithRemote() error { if s.cfg.Pricing.HashURL != "" { remoteHash, err := s.fetchRemoteHash() if err != nil { - log.Printf("[Pricing] Failed to fetch remote hash: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash: %v", err) return nil // 哈希获取失败不影响正常使用 } if remoteHash != localHash { - log.Println("[Pricing] Remote hash differs, downloading new version...") + logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Remote hash differs, downloading new version...") return s.downloadPricingData() } - log.Println("[Pricing] Hash check passed, no update needed") + logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Hash check passed, no update needed") return nil } @@ -205,7 +208,7 @@ func (s *PricingService) syncWithRemote() error { maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour if fileAge > maxAge { - log.Printf("[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour)) + logger.LegacyPrintf("service.pricing", "[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour)) return s.downloadPricingData() } @@ -218,7 +221,7 @@ func (s *PricingService) downloadPricingData() error { if err != nil { return err } - log.Printf("[Pricing] Downloading from %s", remoteURL) + logger.LegacyPrintf("service.pricing", "[Pricing] Downloading from %s", remoteURL) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -252,7 +255,7 @@ func (s *PricingService) downloadPricingData() error { // 保存到本地文件 pricingFile := s.getPricingFilePath() if err := os.WriteFile(pricingFile, body, 0644); err != nil { - log.Printf("[Pricing] Failed to save file: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save file: %v", err) } // 保存哈希 @@ -260,7 +263,7 @@ func (s *PricingService) downloadPricingData() error { hashStr := hex.EncodeToString(hash[:]) hashFile := s.getHashFilePath() if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil { - log.Printf("[Pricing] Failed to save hash: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save hash: %v", err) } // 更新内存数据 @@ -270,7 +273,7 @@ func (s *PricingService) downloadPricingData() error { s.localHash = hashStr s.mu.Unlock() - log.Printf("[Pricing] Downloaded %d models successfully", len(data)) + logger.LegacyPrintf("service.pricing", "[Pricing] Downloaded %d models successfully", len(data)) return nil } @@ -318,6 +321,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel if entry.CacheCreationInputTokenCost != nil { pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost } + if entry.CacheCreationInputTokenCostAbove1hr != nil { + pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr + } if entry.CacheReadInputTokenCost != nil { pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost } @@ -329,7 +335,7 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel } if skipped > 0 { - log.Printf("[Pricing] Skipped %d invalid entries", skipped) + logger.LegacyPrintf("service.pricing", "[Pricing] Skipped %d invalid entries", skipped) } if len(result) == 0 { @@ -368,7 +374,7 @@ func (s *PricingService) loadPricingData(filePath string) error { } s.mu.Unlock() - log.Printf("[Pricing] Loaded %d models from %s", len(pricingData), filePath) + logger.LegacyPrintf("service.pricing", "[Pricing] Loaded %d models from %s", len(pricingData), filePath) return nil } @@ -380,7 +386,7 @@ func (s *PricingService) useFallbackPricing() error { return fmt.Errorf("fallback file not found: %s", fallbackFile) } - log.Printf("[Pricing] Using fallback file: %s", fallbackFile) + logger.LegacyPrintf("service.pricing", "[Pricing] Using fallback file: %s", fallbackFile) // 复制到数据目录 data, err := os.ReadFile(fallbackFile) @@ -390,7 +396,7 @@ func (s *PricingService) useFallbackPricing() error { pricingFile := s.getPricingFilePath() if err := os.WriteFile(pricingFile, data, 0644); err != nil { - log.Printf("[Pricing] Failed to copy fallback: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to copy fallback: %v", err) } return s.loadPricingData(fallbackFile) @@ -639,7 +645,7 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { for key, pricing := range s.pricingData { keyLower := strings.ToLower(key) if strings.Contains(keyLower, pattern) { - log.Printf("[Pricing] Fuzzy matched %s -> %s", model, key) + logger.LegacyPrintf("service.pricing", "[Pricing] Fuzzy matched %s -> %s", model, key) return pricing } } @@ -650,24 +656,36 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { // matchOpenAIModel OpenAI 模型回退匹配策略 // 回退顺序: -// 1. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等) -// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) -// 3. gpt-5.3-codex -> gpt-5.2-codex -// 4. 最终回退到 DefaultTestModel (gpt-5.1-codex) +// 1. gpt-5.3-codex-spark* -> gpt-5.1-codex(按业务要求固定计费) +// 2. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等) +// 3. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) +// 4. gpt-5.3-codex -> gpt-5.2-codex +// 5. 最终回退到 DefaultTestModel (gpt-5.1-codex) func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { + if strings.HasPrefix(model, "gpt-5.3-codex-spark") { + if pricing, ok := s.pricingData["gpt-5.1-codex"]; ok { + logger.LegacyPrintf("service.pricing", "[Pricing][SparkBilling] %s -> %s billing", model, "gpt-5.1-codex") + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.1-codex")) + return pricing + } + } + // 尝试的回退变体 variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern) for _, variant := range variants { if pricing, ok := s.pricingData[variant]; ok { - log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, variant) + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, variant)) return pricing } } if strings.HasPrefix(model, "gpt-5.3-codex") { if pricing, ok := s.pricingData["gpt-5.2-codex"]; ok { - log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex") + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex")) return pricing } } @@ -675,7 +693,7 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { // 最终回退到 DefaultTestModel defaultModel := strings.ToLower(openai.DefaultTestModel) if pricing, ok := s.pricingData[defaultModel]; ok { - log.Printf("[Pricing] OpenAI fallback to default model %s -> %s", model, defaultModel) + logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI fallback to default model %s -> %s", model, defaultModel) return pricing } diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go new file mode 100644 index 00000000..127ff342 --- /dev/null +++ b/backend/internal/service/pricing_service_test.go @@ -0,0 +1,53 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetModelPricing_Gpt53CodexSparkUsesGpt51CodexPricing(t *testing.T) { + sparkPricing := &LiteLLMModelPricing{InputCostPerToken: 1} + gpt53Pricing := &LiteLLMModelPricing{InputCostPerToken: 9} + + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": sparkPricing, + "gpt-5.3": gpt53Pricing, + }, + } + + got := svc.GetModelPricing("gpt-5.3-codex-spark") + require.Same(t, sparkPricing, got) +} + +func TestGetModelPricing_Gpt53CodexFallbackStillUsesGpt52Codex(t *testing.T) { + gpt52CodexPricing := &LiteLLMModelPricing{InputCostPerToken: 2} + + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.2-codex": gpt52CodexPricing, + }, + } + + got := svc.GetModelPricing("gpt-5.3-codex") + require.Same(t, gpt52CodexPricing, got) +} + +func TestGetModelPricing_OpenAIFallbackMatchedLoggedAsInfo(t *testing.T) { + logSink, restore := captureStructuredLog(t) + defer restore() + + gpt52CodexPricing := &LiteLLMModelPricing{InputCostPerToken: 2} + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.2-codex": gpt52CodexPricing, + }, + } + + got := svc.GetModelPricing("gpt-5.3-codex") + require.Same(t, gpt52CodexPricing, got) + + require.True(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "info")) + require.False(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "warn")) +} diff --git a/backend/internal/service/proxy.go b/backend/internal/service/proxy.go index 7eb7728f..fc449091 100644 --- a/backend/internal/service/proxy.go +++ b/backend/internal/service/proxy.go @@ -40,6 +40,11 @@ type ProxyWithAccountCount struct { CountryCode string Region string City string + QualityStatus string + QualityScore *int + QualityGrade string + QualitySummary string + QualityChecked *int64 } type ProxyAccountSummary struct { diff --git a/backend/internal/service/proxy_latency_cache.go b/backend/internal/service/proxy_latency_cache.go index 4a1cc77b..f54bff88 100644 --- a/backend/internal/service/proxy_latency_cache.go +++ b/backend/internal/service/proxy_latency_cache.go @@ -6,15 +6,21 @@ import ( ) type ProxyLatencyInfo struct { - Success bool `json:"success"` - LatencyMs *int64 `json:"latency_ms,omitempty"` - Message string `json:"message,omitempty"` - IPAddress string `json:"ip_address,omitempty"` - Country string `json:"country,omitempty"` - CountryCode string `json:"country_code,omitempty"` - Region string `json:"region,omitempty"` - City string `json:"city,omitempty"` - UpdatedAt time.Time `json:"updated_at"` + Success bool `json:"success"` + LatencyMs *int64 `json:"latency_ms,omitempty"` + Message string `json:"message,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + Region string `json:"region,omitempty"` + City string `json:"city,omitempty"` + QualityStatus string `json:"quality_status,omitempty"` + QualityScore *int `json:"quality_score,omitempty"` + QualityGrade string `json:"quality_grade,omitempty"` + QualitySummary string `json:"quality_summary,omitempty"` + QualityCheckedAt *int64 `json:"quality_checked_at,omitempty"` + QualityCFRay string `json:"quality_cf_ray,omitempty"` + UpdatedAt time.Time `json:"updated_at"` } type ProxyLatencyCache interface { diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 47286deb..d4d70536 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // RateLimitService 处理限流和过载状态管理 @@ -33,6 +34,10 @@ type geminiUsageCacheEntry struct { totals GeminiUsageTotals } +type geminiUsageTotalsBatchProvider interface { + GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]GeminiUsageTotals, error) +} + const geminiPrecheckCacheTTL = time.Minute // NewRateLimitService 创建RateLimitService实例 @@ -62,6 +67,32 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali s.tokenCacheInvalidator = invalidator } +// ErrorPolicyResult 表示错误策略检查的结果 +type ErrorPolicyResult int + +const ( + ErrorPolicyNone ErrorPolicyResult = iota // 未命中任何策略,继续默认逻辑 + ErrorPolicySkipped // 自定义错误码开启但未命中,跳过处理 + ErrorPolicyMatched // 自定义错误码命中,应停止调度 + ErrorPolicyTempUnscheduled // 临时不可调度规则命中 +) + +// CheckErrorPolicy 检查自定义错误码和临时不可调度规则。 +// 自定义错误码开启时覆盖后续所有逻辑(包括临时不可调度)。 +func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Account, statusCode int, responseBody []byte) ErrorPolicyResult { + if account.IsCustomErrorCodesEnabled() { + if account.ShouldHandleErrorCode(statusCode) { + return ErrorPolicyMatched + } + slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode) + return ErrorPolicySkipped + } + if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) { + return ErrorPolicyTempUnscheduled + } + return ErrorPolicyNone +} + // HandleUpstreamError 处理上游错误响应,标记账号状态 // 返回是否应该停止该账号的调度 func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { @@ -136,6 +167,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc if upstreamMsg != "" { msg = "Access forbidden (403): " + upstreamMsg } + logger.LegacyPrintf( + "service.ratelimit", + "[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s", + account.ID, + account.Platform, + account.Type, + strings.TrimSpace(headers.Get("x-request-id")), + strings.TrimSpace(headers.Get("cf-ray")), + upstreamMsg, + truncateForLog(responseBody, 1024), + ) s.handleAuthError(ctx, account, msg) shouldDisable = true case 429: @@ -199,7 +241,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, start := geminiDailyWindowStart(now) totals, ok := s.getGeminiUsageTotals(account.ID, start, now) if !ok { - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return true, err } @@ -246,7 +288,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, if limit > 0 { start := now.Truncate(time.Minute) - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return true, err } @@ -276,6 +318,218 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, return true, nil } +// PreCheckUsageBatch performs quota precheck for multiple accounts in one request. +// Returned map value=false means the account should be skipped. +func (s *RateLimitService) PreCheckUsageBatch(ctx context.Context, accounts []*Account, requestedModel string) (map[int64]bool, error) { + result := make(map[int64]bool, len(accounts)) + for _, account := range accounts { + if account == nil { + continue + } + result[account.ID] = true + } + + if len(accounts) == 0 || requestedModel == "" { + return result, nil + } + if s.usageRepo == nil || s.geminiQuotaService == nil { + return result, nil + } + + modelClass := geminiModelClassFromName(requestedModel) + now := time.Now() + dailyStart := geminiDailyWindowStart(now) + minuteStart := now.Truncate(time.Minute) + + type quotaAccount struct { + account *Account + quota GeminiQuota + } + quotaAccounts := make([]quotaAccount, 0, len(accounts)) + for _, account := range accounts { + if account == nil || account.Platform != PlatformGemini { + continue + } + quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account) + if !ok { + continue + } + quotaAccounts = append(quotaAccounts, quotaAccount{ + account: account, + quota: quota, + }) + } + if len(quotaAccounts) == 0 { + return result, nil + } + + // 1) Daily precheck (cached + batch DB fallback) + dailyTotalsByID := make(map[int64]GeminiUsageTotals, len(quotaAccounts)) + dailyMissIDs := make([]int64, 0, len(quotaAccounts)) + for _, item := range quotaAccounts { + limit := geminiDailyLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + accountID := item.account.ID + if totals, ok := s.getGeminiUsageTotals(accountID, dailyStart, now); ok { + dailyTotalsByID[accountID] = totals + continue + } + dailyMissIDs = append(dailyMissIDs, accountID) + } + if len(dailyMissIDs) > 0 { + totalsBatch, err := s.getGeminiUsageTotalsBatch(ctx, dailyMissIDs, dailyStart, now) + if err != nil { + return result, err + } + for _, accountID := range dailyMissIDs { + totals := totalsBatch[accountID] + dailyTotalsByID[accountID] = totals + s.setGeminiUsageTotals(accountID, dailyStart, now, totals) + } + } + for _, item := range quotaAccounts { + limit := geminiDailyLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + accountID := item.account.ID + used := geminiUsedRequests(item.quota, modelClass, dailyTotalsByID[accountID], true) + if used >= limit { + resetAt := geminiDailyResetTime(now) + slog.Info("gemini_precheck_daily_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt) + result[accountID] = false + } + } + + // 2) Minute precheck (batch DB) + minuteIDs := make([]int64, 0, len(quotaAccounts)) + for _, item := range quotaAccounts { + accountID := item.account.ID + if !result[accountID] { + continue + } + if geminiMinuteLimit(item.quota, modelClass) <= 0 { + continue + } + minuteIDs = append(minuteIDs, accountID) + } + if len(minuteIDs) == 0 { + return result, nil + } + + minuteTotalsByID, err := s.getGeminiUsageTotalsBatch(ctx, minuteIDs, minuteStart, now) + if err != nil { + return result, err + } + for _, item := range quotaAccounts { + accountID := item.account.ID + if !result[accountID] { + continue + } + + limit := geminiMinuteLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + + used := geminiUsedRequests(item.quota, modelClass, minuteTotalsByID[accountID], false) + if used >= limit { + resetAt := minuteStart.Add(time.Minute) + slog.Info("gemini_precheck_minute_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt) + result[accountID] = false + } + } + + return result, nil +} + +func (s *RateLimitService) getGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, start, end time.Time) (map[int64]GeminiUsageTotals, error) { + result := make(map[int64]GeminiUsageTotals, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + ids := make([]int64, 0, len(accountIDs)) + seen := make(map[int64]struct{}, len(accountIDs)) + for _, accountID := range accountIDs { + if accountID <= 0 { + continue + } + if _, ok := seen[accountID]; ok { + continue + } + seen[accountID] = struct{}{} + ids = append(ids, accountID) + } + if len(ids) == 0 { + return result, nil + } + + if batchReader, ok := s.usageRepo.(geminiUsageTotalsBatchProvider); ok { + stats, err := batchReader.GetGeminiUsageTotalsBatch(ctx, ids, start, end) + if err != nil { + return nil, err + } + for _, accountID := range ids { + result[accountID] = stats[accountID] + } + return result, nil + } + + for _, accountID := range ids { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, end, 0, 0, accountID, 0, nil, nil, nil) + if err != nil { + return nil, err + } + result[accountID] = geminiAggregateUsage(stats) + } + return result, nil +} + +func geminiDailyLimit(quota GeminiQuota, modelClass geminiModelClass) int64 { + if quota.SharedRPD > 0 { + return quota.SharedRPD + } + switch modelClass { + case geminiModelFlash: + return quota.FlashRPD + default: + return quota.ProRPD + } +} + +func geminiMinuteLimit(quota GeminiQuota, modelClass geminiModelClass) int64 { + if quota.SharedRPM > 0 { + return quota.SharedRPM + } + switch modelClass { + case geminiModelFlash: + return quota.FlashRPM + default: + return quota.ProRPM + } +} + +func geminiUsedRequests(quota GeminiQuota, modelClass geminiModelClass, totals GeminiUsageTotals, daily bool) int64 { + if daily { + if quota.SharedRPD > 0 { + return totals.ProRequests + totals.FlashRequests + } + } else { + if quota.SharedRPM > 0 { + return totals.ProRequests + totals.FlashRequests + } + } + switch modelClass { + case geminiModelFlash: + return totals.FlashRequests + default: + return totals.ProRequests + } +} + func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) { s.usageCacheMu.RLock() defer s.usageCacheMu.RUnlock() @@ -355,10 +609,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head } } - // 2. 尝试从响应头解析重置时间(Anthropic) + // 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口 + if result := calculateAnthropic429ResetTime(headers); result != nil { + if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil { + slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) + return + } + + // 更新 session window:优先使用 5h-reset 头精确计算,否则从 resetAt 反推 + windowEnd := result.resetAt + if result.fiveHourReset != nil { + windowEnd = *result.fiveHourReset + } + windowStart := windowEnd.Add(-5 * time.Hour) + if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil { + slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err) + } + + slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second)) + return + } + + // 3. 尝试从响应头解析重置时间(Anthropic 聚合头,向后兼容) resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset") - // 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini) + // 4. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini) if resetTimestamp == "" { switch account.Platform { case PlatformOpenAI: @@ -471,6 +746,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim return nil } +// anthropic429Result holds the parsed Anthropic 429 rate-limit information. +type anthropic429Result struct { + resetAt time.Time // The correct reset time to use for SetRateLimited + fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available +} + +// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers +// to determine which window (5h or 7d) actually triggered the 429. +// +// Headers used: +// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold +// - anthropic-ratelimit-unified-5h-reset +// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold +// - anthropic-ratelimit-unified-7d-reset +// +// Returns nil when the per-window headers are absent (caller should fall back to +// the aggregated anthropic-ratelimit-unified-reset header). +func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result { + reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset") + reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset") + + if reset5hStr == "" && reset7dStr == "" { + return nil + } + + var reset5h, reset7d *time.Time + if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil { + t := time.Unix(ts, 0) + reset5h = &t + } + if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil { + t := time.Unix(ts, 0) + reset7d = &t + } + + is5hExceeded := isAnthropicWindowExceeded(headers, "5h") + is7dExceeded := isAnthropicWindowExceeded(headers, "7d") + + slog.Info("anthropic_429_window_analysis", + "is_5h_exceeded", is5hExceeded, + "is_7d_exceeded", is7dExceeded, + "reset_5h", reset5hStr, + "reset_7d", reset7dStr, + ) + + // Select the correct reset time based on which window(s) are exceeded. + var chosen *time.Time + switch { + case is5hExceeded && is7dExceeded: + // Both exceeded → prefer 7d (longer cooldown), fall back to 5h + chosen = reset7d + if chosen == nil { + chosen = reset5h + } + case is5hExceeded: + chosen = reset5h + case is7dExceeded: + chosen = reset7d + default: + // Neither flag clearly exceeded — pick the sooner reset as best guess + chosen = pickSooner(reset5h, reset7d) + } + + if chosen == nil { + return nil + } + return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h} +} + +// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window +// (e.g. "5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers. +func isAnthropicWindowExceeded(headers http.Header, window string) bool { + prefix := "anthropic-ratelimit-unified-" + window + "-" + + // Check surpassed-threshold first (most explicit signal) + if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") { + return true + } + + // Fall back to utilization >= 1.0 + if utilStr := headers.Get(prefix + "utilization"); utilStr != "" { + if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 { + // Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0 + return true + } + } + + return false +} + +// pickSooner returns whichever of the two time pointers is earlier. +// If only one is non-nil, it is returned. If both are nil, returns nil. +func pickSooner(a, b *time.Time) *time.Time { + switch { + case a != nil && b != nil: + if a.Before(*b) { + return a + } + return b + case a != nil: + return a + default: + return b + } +} + // parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳 // OpenAI 的 usage_limit_reached 错误格式: // @@ -585,7 +966,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID); err != nil { return err } - return s.accountRepo.ClearModelRateLimits(ctx, accountID) + if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil { + return err + } + // 清除限流时一并清理临时不可调度状态,避免周限/窗口重置后仍被本地临时状态阻断。 + if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil { + return err + } + if s.tempUnschedCache != nil { + if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil { + slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err) + } + } + return nil } func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error { @@ -597,6 +990,10 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err) } } + // 同时清除模型级别限流 + if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil { + slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err) + } return nil } diff --git a/backend/internal/service/ratelimit_service_anthropic_test.go b/backend/internal/service/ratelimit_service_anthropic_test.go new file mode 100644 index 00000000..eaeaf30e --- /dev/null +++ b/backend/internal/service/ratelimit_service_anthropic_test.go @@ -0,0 +1,202 @@ +package service + +import ( + "net/http" + "testing" + "time" +) + +func TestCalculateAnthropic429ResetTime_Only5hExceeded(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.02") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.32") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) + + if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) { + t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset) + } +} + +func TestCalculateAnthropic429ResetTime_Only7dExceeded(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.50") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.05") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) + + // fiveHourReset should still be populated for session window calculation + if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) { + t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset) + } +} + +func TestCalculateAnthropic429ResetTime_BothExceeded(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.10") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.02") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) +} + +func TestCalculateAnthropic429ResetTime_NoPerWindowHeaders(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + if result != nil { + t.Errorf("expected nil result when no per-window headers, got resetAt=%v", result.resetAt) + } +} + +func TestCalculateAnthropic429ResetTime_NoHeaders(t *testing.T) { + result := calculateAnthropic429ResetTime(http.Header{}) + if result != nil { + t.Errorf("expected nil result for empty headers, got resetAt=%v", result.resetAt) + } +} + +func TestCalculateAnthropic429ResetTime_SurpassedThreshold(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "false") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_UtilizationExactlyOne(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.0") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.5") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_NeitherExceeded_UsesShorter(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.95") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") // sooner + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.80") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") // later + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_Only5hResetHeader(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.05") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_Only7dResetHeader(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.03") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) + + if result.fiveHourReset != nil { + t.Errorf("expected fiveHourReset=nil when no 5h headers, got %v", result.fiveHourReset) + } +} + +func TestIsAnthropicWindowExceeded(t *testing.T) { + tests := []struct { + name string + headers http.Header + window string + expected bool + }{ + { + name: "utilization above 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.02"), + window: "5h", + expected: true, + }, + { + name: "utilization exactly 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.0"), + window: "5h", + expected: true, + }, + { + name: "utilization below 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "0.99"), + window: "5h", + expected: false, + }, + { + name: "surpassed-threshold true", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "true"), + window: "7d", + expected: true, + }, + { + name: "surpassed-threshold True (case insensitive)", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "True"), + window: "7d", + expected: true, + }, + { + name: "surpassed-threshold false", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "false"), + window: "7d", + expected: false, + }, + { + name: "no headers", + headers: http.Header{}, + window: "5h", + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isAnthropicWindowExceeded(tc.headers, tc.window) + if got != tc.expected { + t.Errorf("expected %v, got %v", tc.expected, got) + } + }) + } +} + +// assertAnthropicResult is a test helper that verifies the result is non-nil and +// has the expected resetAt unix timestamp. +func assertAnthropicResult(t *testing.T, result *anthropic429Result, wantUnix int64) { + t.Helper() + if result == nil { + t.Fatal("expected non-nil result") + return // unreachable, but satisfies staticcheck SA5011 + } + want := time.Unix(wantUnix, 0) + if !result.resetAt.Equal(want) { + t.Errorf("expected resetAt=%v, got %v", want, result.resetAt) + } +} + +func makeHeader(key, value string) http.Header { + h := http.Header{} + h.Set(key, value) + return h +} diff --git a/backend/internal/service/ratelimit_service_clear_test.go b/backend/internal/service/ratelimit_service_clear_test.go new file mode 100644 index 00000000..f48151ed --- /dev/null +++ b/backend/internal/service/ratelimit_service_clear_test.go @@ -0,0 +1,172 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type rateLimitClearRepoStub struct { + mockAccountRepoForGemini + clearRateLimitCalls int + clearAntigravityCalls int + clearModelRateLimitCalls int + clearTempUnschedCalls int + clearRateLimitErr error + clearAntigravityErr error + clearModelRateLimitErr error + clearTempUnschedulableErr error +} + +func (r *rateLimitClearRepoStub) ClearRateLimit(ctx context.Context, id int64) error { + r.clearRateLimitCalls++ + return r.clearRateLimitErr +} + +func (r *rateLimitClearRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + r.clearAntigravityCalls++ + return r.clearAntigravityErr +} + +func (r *rateLimitClearRepoStub) ClearModelRateLimits(ctx context.Context, id int64) error { + r.clearModelRateLimitCalls++ + return r.clearModelRateLimitErr +} + +func (r *rateLimitClearRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error { + r.clearTempUnschedCalls++ + return r.clearTempUnschedulableErr +} + +type tempUnschedCacheRecorder struct { + deletedIDs []int64 + deleteErr error +} + +func (c *tempUnschedCacheRecorder) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error { + return nil +} + +func (c *tempUnschedCacheRecorder) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) { + return nil, nil +} + +func (c *tempUnschedCacheRecorder) DeleteTempUnsched(ctx context.Context, accountID int64) error { + c.deletedIDs = append(c.deletedIDs, accountID) + return c.deleteErr +} + +func TestRateLimitService_ClearRateLimit_AlsoClearsTempUnschedulable(t *testing.T) { + repo := &rateLimitClearRepoStub{} + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 42) + require.NoError(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Equal(t, []int64{42}, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearTempUnschedulableFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + clearTempUnschedulableErr: errors.New("clear temp unsched failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 7) + require.Error(t, err) + + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearRateLimitFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + clearRateLimitErr: errors.New("clear rate limit failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 11) + require.Error(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 0, repo.clearAntigravityCalls) + require.Equal(t, 0, repo.clearModelRateLimitCalls) + require.Equal(t, 0, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearAntigravityFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + clearAntigravityErr: errors.New("clear antigravity failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 12) + require.Error(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 0, repo.clearModelRateLimitCalls) + require.Equal(t, 0, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearModelRateLimitsFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + clearModelRateLimitErr: errors.New("clear model rate limits failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 13) + require.Error(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 0, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_CacheDeleteFailedShouldNotFail(t *testing.T) { + repo := &rateLimitClearRepoStub{} + cache := &tempUnschedCacheRecorder{ + deleteErr: errors.New("cache delete failed"), + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 14) + require.NoError(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Equal(t, []int64{14}, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_WithoutTempUnschedCache(t *testing.T) { + repo := &rateLimitClearRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + err := svc.ClearRateLimit(context.Background(), 15) + require.NoError(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) +} diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index ad277ca0..b22da752 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -174,6 +174,33 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ return codes, nil } +// CreateCode creates a redeem code with caller-provided code value. +// It is primarily used by admin integrations that require an external order ID +// to be mapped to a deterministic redeem code. +func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error { + if code == nil { + return errors.New("redeem code is required") + } + code.Code = strings.TrimSpace(code.Code) + if code.Code == "" { + return errors.New("code is required") + } + if code.Type == "" { + code.Type = RedeemTypeBalance + } + if code.Type != RedeemTypeInvitation && code.Value <= 0 { + return errors.New("value must be greater than 0") + } + if code.Status == "" { + code.Status = StatusUnused + } + + if err := s.redeemRepo.Create(ctx, code); err != nil { + return fmt.Errorf("create redeem code: %w", err) + } + return nil +} + // checkRedeemRateLimit 检查用户兑换错误次数是否超限 func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error { if s.cache == nil { diff --git a/backend/internal/service/request_metadata.go b/backend/internal/service/request_metadata.go new file mode 100644 index 00000000..5c81bbf1 --- /dev/null +++ b/backend/internal/service/request_metadata.go @@ -0,0 +1,216 @@ +package service + +import ( + "context" + "sync/atomic" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" +) + +type requestMetadataContextKey struct{} + +var requestMetadataKey = requestMetadataContextKey{} + +type RequestMetadata struct { + IsMaxTokensOneHaikuRequest *bool + ThinkingEnabled *bool + PrefetchedStickyAccountID *int64 + PrefetchedStickyGroupID *int64 + SingleAccountRetry *bool + AccountSwitchCount *int +} + +var ( + requestMetadataFallbackIsMaxTokensOneHaikuTotal atomic.Int64 + requestMetadataFallbackThinkingEnabledTotal atomic.Int64 + requestMetadataFallbackPrefetchedStickyAccount atomic.Int64 + requestMetadataFallbackPrefetchedStickyGroup atomic.Int64 + requestMetadataFallbackSingleAccountRetryTotal atomic.Int64 + requestMetadataFallbackAccountSwitchCountTotal atomic.Int64 +) + +func RequestMetadataFallbackStats() (isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount int64) { + return requestMetadataFallbackIsMaxTokensOneHaikuTotal.Load(), + requestMetadataFallbackThinkingEnabledTotal.Load(), + requestMetadataFallbackPrefetchedStickyAccount.Load(), + requestMetadataFallbackPrefetchedStickyGroup.Load(), + requestMetadataFallbackSingleAccountRetryTotal.Load(), + requestMetadataFallbackAccountSwitchCountTotal.Load() +} + +func metadataFromContext(ctx context.Context) *RequestMetadata { + if ctx == nil { + return nil + } + md, _ := ctx.Value(requestMetadataKey).(*RequestMetadata) + return md +} + +func updateRequestMetadata( + ctx context.Context, + bridgeOldKeys bool, + update func(md *RequestMetadata), + legacyBridge func(ctx context.Context) context.Context, +) context.Context { + if ctx == nil { + return nil + } + current := metadataFromContext(ctx) + next := &RequestMetadata{} + if current != nil { + *next = *current + } + update(next) + ctx = context.WithValue(ctx, requestMetadataKey, next) + if bridgeOldKeys && legacyBridge != nil { + ctx = legacyBridge(ctx) + } + return ctx +} + +func WithIsMaxTokensOneHaikuRequest(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.IsMaxTokensOneHaikuRequest = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.IsMaxTokensOneHaikuRequest, value) + }) +} + +func WithThinkingEnabled(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.ThinkingEnabled = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.ThinkingEnabled, value) + }) +} + +func WithPrefetchedStickySession(ctx context.Context, accountID, groupID int64, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + account := accountID + group := groupID + md.PrefetchedStickyAccountID = &account + md.PrefetchedStickyGroupID = &group + }, func(base context.Context) context.Context { + bridged := context.WithValue(base, ctxkey.PrefetchedStickyAccountID, accountID) + return context.WithValue(bridged, ctxkey.PrefetchedStickyGroupID, groupID) + }) +} + +func WithSingleAccountRetry(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.SingleAccountRetry = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.SingleAccountRetry, value) + }) +} + +func WithAccountSwitchCount(ctx context.Context, value int, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.AccountSwitchCount = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.AccountSwitchCount, value) + }) +} + +func IsMaxTokensOneHaikuRequestFromContext(ctx context.Context) (bool, bool) { + if md := metadataFromContext(ctx); md != nil && md.IsMaxTokensOneHaikuRequest != nil { + return *md.IsMaxTokensOneHaikuRequest, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok { + requestMetadataFallbackIsMaxTokensOneHaikuTotal.Add(1) + return value, true + } + return false, false +} + +func ThinkingEnabledFromContext(ctx context.Context) (bool, bool) { + if md := metadataFromContext(ctx); md != nil && md.ThinkingEnabled != nil { + return *md.ThinkingEnabled, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + requestMetadataFallbackThinkingEnabledTotal.Add(1) + return value, true + } + return false, false +} + +func PrefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { + if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyGroupID != nil { + return *md.PrefetchedStickyGroupID, true + } + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.PrefetchedStickyGroupID) + switch t := v.(type) { + case int64: + requestMetadataFallbackPrefetchedStickyGroup.Add(1) + return t, true + case int: + requestMetadataFallbackPrefetchedStickyGroup.Add(1) + return int64(t), true + } + return 0, false +} + +func PrefetchedStickyAccountIDFromContext(ctx context.Context) (int64, bool) { + if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyAccountID != nil { + return *md.PrefetchedStickyAccountID, true + } + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.PrefetchedStickyAccountID) + switch t := v.(type) { + case int64: + requestMetadataFallbackPrefetchedStickyAccount.Add(1) + return t, true + case int: + requestMetadataFallbackPrefetchedStickyAccount.Add(1) + return int64(t), true + } + return 0, false +} + +func SingleAccountRetryFromContext(ctx context.Context) (bool, bool) { + if md := metadataFromContext(ctx); md != nil && md.SingleAccountRetry != nil { + return *md.SingleAccountRetry, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.SingleAccountRetry).(bool); ok { + requestMetadataFallbackSingleAccountRetryTotal.Add(1) + return value, true + } + return false, false +} + +func AccountSwitchCountFromContext(ctx context.Context) (int, bool) { + if md := metadataFromContext(ctx); md != nil && md.AccountSwitchCount != nil { + return *md.AccountSwitchCount, true + } + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.AccountSwitchCount) + switch t := v.(type) { + case int: + requestMetadataFallbackAccountSwitchCountTotal.Add(1) + return t, true + case int64: + requestMetadataFallbackAccountSwitchCountTotal.Add(1) + return int(t), true + } + return 0, false +} diff --git a/backend/internal/service/request_metadata_test.go b/backend/internal/service/request_metadata_test.go new file mode 100644 index 00000000..7d192699 --- /dev/null +++ b/backend/internal/service/request_metadata_test.go @@ -0,0 +1,119 @@ +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestRequestMetadataWriteAndRead_NoBridge(t *testing.T) { + ctx := context.Background() + ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, false) + ctx = WithThinkingEnabled(ctx, true, false) + ctx = WithPrefetchedStickySession(ctx, 123, 456, false) + ctx = WithSingleAccountRetry(ctx, true, false) + ctx = WithAccountSwitchCount(ctx, 2, false) + + isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx) + require.True(t, ok) + require.True(t, isHaiku) + + thinking, ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + + accountID, ok := PrefetchedStickyAccountIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(123), accountID) + + groupID, ok := PrefetchedStickyGroupIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(456), groupID) + + singleRetry, ok := SingleAccountRetryFromContext(ctx) + require.True(t, ok) + require.True(t, singleRetry) + + switchCount, ok := AccountSwitchCountFromContext(ctx) + require.True(t, ok) + require.Equal(t, 2, switchCount) + + require.Nil(t, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest)) + require.Nil(t, ctx.Value(ctxkey.ThinkingEnabled)) + require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyAccountID)) + require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyGroupID)) + require.Nil(t, ctx.Value(ctxkey.SingleAccountRetry)) + require.Nil(t, ctx.Value(ctxkey.AccountSwitchCount)) +} + +func TestRequestMetadataWrite_BridgeLegacyKeys(t *testing.T) { + ctx := context.Background() + ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, true) + ctx = WithThinkingEnabled(ctx, true, true) + ctx = WithPrefetchedStickySession(ctx, 123, 456, true) + ctx = WithSingleAccountRetry(ctx, true, true) + ctx = WithAccountSwitchCount(ctx, 2, true) + + require.Equal(t, true, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest)) + require.Equal(t, true, ctx.Value(ctxkey.ThinkingEnabled)) + require.Equal(t, int64(123), ctx.Value(ctxkey.PrefetchedStickyAccountID)) + require.Equal(t, int64(456), ctx.Value(ctxkey.PrefetchedStickyGroupID)) + require.Equal(t, true, ctx.Value(ctxkey.SingleAccountRetry)) + require.Equal(t, 2, ctx.Value(ctxkey.AccountSwitchCount)) +} + +func TestRequestMetadataRead_LegacyFallbackAndStats(t *testing.T) { + beforeHaiku, beforeThinking, beforeAccount, beforeGroup, beforeSingleRetry, beforeSwitchCount := RequestMetadataFallbackStats() + + ctx := context.Background() + ctx = context.WithValue(ctx, ctxkey.IsMaxTokensOneHaikuRequest, true) + ctx = context.WithValue(ctx, ctxkey.ThinkingEnabled, true) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyAccountID, int64(321)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(654)) + ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true) + ctx = context.WithValue(ctx, ctxkey.AccountSwitchCount, int64(3)) + + isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx) + require.True(t, ok) + require.True(t, isHaiku) + + thinking, ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + + accountID, ok := PrefetchedStickyAccountIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(321), accountID) + + groupID, ok := PrefetchedStickyGroupIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(654), groupID) + + singleRetry, ok := SingleAccountRetryFromContext(ctx) + require.True(t, ok) + require.True(t, singleRetry) + + switchCount, ok := AccountSwitchCountFromContext(ctx) + require.True(t, ok) + require.Equal(t, 3, switchCount) + + afterHaiku, afterThinking, afterAccount, afterGroup, afterSingleRetry, afterSwitchCount := RequestMetadataFallbackStats() + require.Equal(t, beforeHaiku+1, afterHaiku) + require.Equal(t, beforeThinking+1, afterThinking) + require.Equal(t, beforeAccount+1, afterAccount) + require.Equal(t, beforeGroup+1, afterGroup) + require.Equal(t, beforeSingleRetry+1, afterSingleRetry) + require.Equal(t, beforeSwitchCount+1, afterSwitchCount) +} + +func TestRequestMetadataRead_PreferMetadataOverLegacy(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, false) + ctx = WithThinkingEnabled(ctx, true, false) + + thinking, ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + require.Equal(t, false, ctx.Value(ctxkey.ThinkingEnabled)) +} diff --git a/backend/internal/service/response_header_filter.go b/backend/internal/service/response_header_filter.go new file mode 100644 index 00000000..81012b01 --- /dev/null +++ b/backend/internal/service/response_header_filter.go @@ -0,0 +1,13 @@ +package service + +import ( + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" +) + +func compileResponseHeaderFilter(cfg *config.Config) *responseheaders.CompiledHeaderFilter { + if cfg == nil { + return nil + } + return responseheaders.CompileHeaderFilter(cfg.Security.ResponseHeaders) +} diff --git a/backend/internal/service/rpm_cache.go b/backend/internal/service/rpm_cache.go new file mode 100644 index 00000000..07036219 --- /dev/null +++ b/backend/internal/service/rpm_cache.go @@ -0,0 +1,17 @@ +package service + +import "context" + +// RPMCache RPM 计数器缓存接口 +// 用于 Anthropic OAuth/SetupToken 账号的每分钟请求数限制 +type RPMCache interface { + // IncrementRPM 原子递增并返回当前分钟的计数 + // 使用 Redis 服务器时间确定 minute key,避免多实例时钟偏差 + IncrementRPM(ctx context.Context, accountID int64) (count int, err error) + + // GetRPM 获取当前分钟的 RPM 计数 + GetRPM(ctx context.Context, accountID int64) (count int, err error) + + // GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline) + GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) +} diff --git a/backend/internal/service/scheduler_shuffle_test.go b/backend/internal/service/scheduler_shuffle_test.go new file mode 100644 index 00000000..0d82b2f3 --- /dev/null +++ b/backend/internal/service/scheduler_shuffle_test.go @@ -0,0 +1,318 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// ============ shuffleWithinSortGroups 测试 ============ + +func TestShuffleWithinSortGroups_Empty(t *testing.T) { + shuffleWithinSortGroups(nil) + shuffleWithinSortGroups([]accountWithLoad{}) +} + +func TestShuffleWithinSortGroups_SingleElement(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + shuffleWithinSortGroups(accounts) + require.Equal(t, int64(1), accounts[0].account.ID) +} + +func TestShuffleWithinSortGroups_DifferentGroups_OrderPreserved(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + {account: &Account{ID: 3, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + + // 每个元素都属于不同组(Priority 或 LoadRate 或 LastUsedAt 不同),顺序不变 + for i := 0; i < 20; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + require.Equal(t, int64(1), cpy[0].account.ID) + require.Equal(t, int64(2), cpy[1].account.ID) + require.Equal(t, int64(3), cpy[2].account.ID) + } +} + +func TestShuffleWithinSortGroups_SameGroup_Shuffled(t *testing.T) { + now := time.Now() + // 同一秒的时间戳视为同一组 + sameSecond := time.Unix(now.Unix(), 0) + sameSecond2 := time.Unix(now.Unix(), 500_000_000) // 同一秒但不同纳秒 + + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: &sameSecond2}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + + // 多次执行,验证所有 ID 都出现在第一个位置(说明确实被打乱了) + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + seen[cpy[0].account.ID] = true + // 无论怎么打乱,所有 ID 都应在候选中 + ids := map[int64]bool{} + for _, a := range cpy { + ids[a.account.ID] = true + } + require.True(t, ids[1] && ids[2] && ids[3]) + } + // 至少 2 个不同的 ID 出现在首位(随机性验证) + require.GreaterOrEqual(t, len(seen), 2, "shuffle should produce different orderings") +} + +func TestShuffleWithinSortGroups_NilLastUsedAt_SameGroup(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + {account: &Account{ID: 3, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + } + + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + seen[cpy[0].account.ID] = true + } + require.GreaterOrEqual(t, len(seen), 2, "nil LastUsedAt accounts should be shuffled") +} + +func TestShuffleWithinSortGroups_MixedGroups(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + sameAsNow := time.Unix(now.Unix(), 0) + + // 组1: Priority=1, LoadRate=10, LastUsedAt=earlier (ID 1) — 单元素组 + // 组2: Priority=1, LoadRate=20, LastUsedAt=now (ID 2, 3) — 双元素组 + // 组3: Priority=2, LoadRate=10, LastUsedAt=earlier (ID 4) — 单元素组 + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + {account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameAsNow}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + {account: &Account{ID: 4, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + + for i := 0; i < 20; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + + // 组间顺序不变 + require.Equal(t, int64(1), cpy[0].account.ID, "group 1 position fixed") + require.Equal(t, int64(4), cpy[3].account.ID, "group 3 position fixed") + + // 组2 内部可以打乱,但仍在位置 1 和 2 + mid := map[int64]bool{cpy[1].account.ID: true, cpy[2].account.ID: true} + require.True(t, mid[2] && mid[3], "group 2 elements should stay in positions 1-2") + } +} + +// ============ shuffleWithinPriorityAndLastUsed 测试 ============ + +func TestShuffleWithinPriorityAndLastUsed_Empty(t *testing.T) { + shuffleWithinPriorityAndLastUsed(nil, false) + shuffleWithinPriorityAndLastUsed([]*Account{}, false) +} + +func TestShuffleWithinPriorityAndLastUsed_SingleElement(t *testing.T) { + accounts := []*Account{{ID: 1, Priority: 1}} + shuffleWithinPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(1), accounts[0].ID) +} + +func TestShuffleWithinPriorityAndLastUsed_SameGroup_Shuffled(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: nil}, + {ID: 3, Priority: 1, LastUsedAt: nil}, + } + + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + shuffleWithinPriorityAndLastUsed(cpy, false) + seen[cpy[0].ID] = true + } + require.GreaterOrEqual(t, len(seen), 2, "same group should be shuffled") +} + +func TestShuffleWithinPriorityAndLastUsed_DifferentPriority_OrderPreserved(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 2, LastUsedAt: nil}, + {ID: 3, Priority: 3, LastUsedAt: nil}, + } + + for i := 0; i < 20; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + shuffleWithinPriorityAndLastUsed(cpy, false) + require.Equal(t, int64(1), cpy[0].ID) + require.Equal(t, int64(2), cpy[1].ID) + require.Equal(t, int64(3), cpy[2].ID) + } +} + +func TestShuffleWithinPriorityAndLastUsed_DifferentLastUsedAt_OrderPreserved(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: &earlier}, + {ID: 3, Priority: 1, LastUsedAt: &now}, + } + + for i := 0; i < 20; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + shuffleWithinPriorityAndLastUsed(cpy, false) + require.Equal(t, int64(1), cpy[0].ID) + require.Equal(t, int64(2), cpy[1].ID) + require.Equal(t, int64(3), cpy[2].ID) + } +} + +// ============ sameLastUsedAt 测试 ============ + +func TestSameLastUsedAt(t *testing.T) { + now := time.Now() + sameSecond := time.Unix(now.Unix(), 0) + sameSecondDiffNano := time.Unix(now.Unix(), 999_999_999) + differentSecond := now.Add(1 * time.Second) + + t.Run("both nil", func(t *testing.T) { + require.True(t, sameLastUsedAt(nil, nil)) + }) + + t.Run("one nil one not", func(t *testing.T) { + require.False(t, sameLastUsedAt(nil, &now)) + require.False(t, sameLastUsedAt(&now, nil)) + }) + + t.Run("same second different nanoseconds", func(t *testing.T) { + require.True(t, sameLastUsedAt(&sameSecond, &sameSecondDiffNano)) + }) + + t.Run("different seconds", func(t *testing.T) { + require.False(t, sameLastUsedAt(&now, &differentSecond)) + }) + + t.Run("exact same time", func(t *testing.T) { + require.True(t, sameLastUsedAt(&now, &now)) + }) +} + +// ============ sameAccountWithLoadGroup 测试 ============ + +func TestSameAccountWithLoadGroup(t *testing.T) { + now := time.Now() + sameSecond := time.Unix(now.Unix(), 0) + + t.Run("same group", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + require.True(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("different priority", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 2, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + require.False(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("different load rate", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}} + require.False(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("different last used at", func(t *testing.T) { + later := now.Add(1 * time.Second) + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &later}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + require.False(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("both nil LastUsedAt", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}} + require.True(t, sameAccountWithLoadGroup(a, b)) + }) +} + +// ============ sameAccountGroup 测试 ============ + +func TestSameAccountGroup(t *testing.T) { + now := time.Now() + + t.Run("same group", func(t *testing.T) { + a := &Account{Priority: 1, LastUsedAt: nil} + b := &Account{Priority: 1, LastUsedAt: nil} + require.True(t, sameAccountGroup(a, b)) + }) + + t.Run("different priority", func(t *testing.T) { + a := &Account{Priority: 1, LastUsedAt: nil} + b := &Account{Priority: 2, LastUsedAt: nil} + require.False(t, sameAccountGroup(a, b)) + }) + + t.Run("different LastUsedAt", func(t *testing.T) { + later := now.Add(1 * time.Second) + a := &Account{Priority: 1, LastUsedAt: &now} + b := &Account{Priority: 1, LastUsedAt: &later} + require.False(t, sameAccountGroup(a, b)) + }) +} + +// ============ sortAccountsByPriorityAndLastUsed 集成随机化测试 ============ + +func TestSortAccountsByPriorityAndLastUsed_WithShuffle(t *testing.T) { + t.Run("same priority and nil LastUsedAt are shuffled", func(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: nil}, + {ID: 3, Priority: 1, LastUsedAt: nil}, + } + + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + sortAccountsByPriorityAndLastUsed(cpy, false) + seen[cpy[0].ID] = true + } + require.GreaterOrEqual(t, len(seen), 2, "identical sort keys should produce different orderings after shuffle") + }) + + t.Run("different priorities still sorted correctly", func(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 3, Priority: 3, LastUsedAt: &now}, + {ID: 1, Priority: 1, LastUsedAt: &now}, + {ID: 2, Priority: 2, LastUsedAt: &now}, + } + + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(1), accounts[0].ID) + require.Equal(t, int64(2), accounts[1].ID) + require.Equal(t, int64(3), accounts[2].ID) + }) +} diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 52d455b8..9f8fa14a 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -4,12 +4,13 @@ import ( "context" "encoding/json" "errors" - "log" + "log/slog" "strconv" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) var ( @@ -103,7 +104,7 @@ func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context, if s.cache != nil { cached, hit, err := s.cache.GetSnapshot(ctx, bucket) if err != nil { - log.Printf("[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err) } else if hit { return derefAccounts(cached), useMixed, nil } @@ -123,7 +124,7 @@ func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context, if s.cache != nil { if err := s.cache.SetSnapshot(fallbackCtx, bucket, accounts); err != nil { - log.Printf("[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err) } } @@ -137,7 +138,7 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int if s.cache != nil { account, err := s.cache.GetAccount(ctx, accountID) if err != nil { - log.Printf("[Scheduler] account cache read failed: id=%d err=%v", accountID, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] account cache read failed: id=%d err=%v", accountID, err) } else if account != nil { return account, nil } @@ -167,17 +168,17 @@ func (s *SchedulerSnapshotService) runInitialRebuild() { defer cancel() buckets, err := s.cache.ListBuckets(ctx) if err != nil { - log.Printf("[Scheduler] list buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] list buckets failed: %v", err) } if len(buckets) == 0 { buckets, err = s.defaultBuckets(ctx) if err != nil { - log.Printf("[Scheduler] default buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] default buckets failed: %v", err) return } } if err := s.rebuildBuckets(ctx, buckets, "startup"); err != nil { - log.Printf("[Scheduler] rebuild startup failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild startup failed: %v", err) } } @@ -204,7 +205,7 @@ func (s *SchedulerSnapshotService) runFullRebuildWorker(interval time.Duration) select { case <-ticker.C: if err := s.triggerFullRebuild("interval"); err != nil { - log.Printf("[Scheduler] full rebuild failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] full rebuild failed: %v", err) } case <-s.stopCh: return @@ -221,13 +222,13 @@ func (s *SchedulerSnapshotService) pollOutbox() { watermark, err := s.cache.GetOutboxWatermark(ctx) if err != nil { - log.Printf("[Scheduler] outbox watermark read failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark read failed: %v", err) return } events, err := s.outboxRepo.ListAfter(ctx, watermark, 200) if err != nil { - log.Printf("[Scheduler] outbox poll failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox poll failed: %v", err) return } if len(events) == 0 { @@ -240,14 +241,14 @@ func (s *SchedulerSnapshotService) pollOutbox() { err := s.handleOutboxEvent(eventCtx, event) cancel() if err != nil { - log.Printf("[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err) return } } lastID := events[len(events)-1].ID if err := s.cache.SetOutboxWatermark(ctx, lastID); err != nil { - log.Printf("[Scheduler] outbox watermark write failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark write failed: %v", err) } else { watermarkForCheck = lastID } @@ -304,13 +305,78 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p if payload == nil { return nil } - ids := parseInt64Slice(payload["account_ids"]) - for _, id := range ids { - if err := s.handleAccountEvent(ctx, &id, payload); err != nil { - return err + if s.accountRepo == nil { + return nil + } + + rawIDs := parseInt64Slice(payload["account_ids"]) + if len(rawIDs) == 0 { + return nil + } + + ids := make([]int64, 0, len(rawIDs)) + seen := make(map[int64]struct{}, len(rawIDs)) + for _, id := range rawIDs { + if id <= 0 { + continue + } + if _, exists := seen[id]; exists { + continue + } + seen[id] = struct{}{} + ids = append(ids, id) + } + if len(ids) == 0 { + return nil + } + + preloadGroupIDs := parseInt64Slice(payload["group_ids"]) + accounts, err := s.accountRepo.GetByIDs(ctx, ids) + if err != nil { + return err + } + + found := make(map[int64]struct{}, len(accounts)) + rebuildGroupSet := make(map[int64]struct{}, len(preloadGroupIDs)) + for _, gid := range preloadGroupIDs { + if gid > 0 { + rebuildGroupSet[gid] = struct{}{} } } - return nil + + for _, account := range accounts { + if account == nil || account.ID <= 0 { + continue + } + found[account.ID] = struct{}{} + if s.cache != nil { + if err := s.cache.SetAccount(ctx, account); err != nil { + return err + } + } + for _, gid := range account.GroupIDs { + if gid > 0 { + rebuildGroupSet[gid] = struct{}{} + } + } + } + + if s.cache != nil { + for _, id := range ids { + if _, ok := found[id]; ok { + continue + } + if err := s.cache.DeleteAccount(ctx, id); err != nil { + return err + } + } + } + + rebuildGroupIDs := make([]int64, 0, len(rebuildGroupSet)) + for gid := range rebuildGroupSet { + rebuildGroupIDs = append(rebuildGroupIDs, gid) + } + return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change") } func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error { @@ -444,14 +510,14 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch accounts, err := s.loadAccountsFromDB(rebuildCtx, bucket, bucket.Mode == SchedulerModeMixed) if err != nil { - log.Printf("[Scheduler] rebuild failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err) return err } if err := s.cache.SetSnapshot(rebuildCtx, bucket, accounts); err != nil { - log.Printf("[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err) return err } - log.Printf("[Scheduler] rebuild ok: bucket=%s reason=%s size=%d", bucket.String(), reason, len(accounts)) + slog.Debug("[Scheduler] rebuild ok", "bucket", bucket.String(), "reason", reason, "size", len(accounts)) return nil } @@ -464,13 +530,13 @@ func (s *SchedulerSnapshotService) triggerFullRebuild(reason string) error { buckets, err := s.cache.ListBuckets(ctx) if err != nil { - log.Printf("[Scheduler] list buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] list buckets failed: %v", err) return err } if len(buckets) == 0 { buckets, err = s.defaultBuckets(ctx) if err != nil { - log.Printf("[Scheduler] default buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] default buckets failed: %v", err) return err } } @@ -484,7 +550,7 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc lag := time.Since(oldest.CreatedAt) if lagSeconds := int(lag.Seconds()); lagSeconds >= s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds && s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds > 0 { - log.Printf("[Scheduler] outbox lag warning: %ds", lagSeconds) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox lag warning: %ds", lagSeconds) } if s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 && int(lag.Seconds()) >= s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds { @@ -494,12 +560,12 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc s.lagMu.Unlock() if failures >= s.cfg.Gateway.Scheduling.OutboxLagRebuildFailures { - log.Printf("[Scheduler] outbox lag rebuild triggered: lag=%s failures=%d", lag, failures) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox lag rebuild triggered: lag=%s failures=%d", lag, failures) s.lagMu.Lock() s.lagFailures = 0 s.lagMu.Unlock() if err := s.triggerFullRebuild("outbox_lag"); err != nil { - log.Printf("[Scheduler] outbox lag rebuild failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox lag rebuild failed: %v", err) } } } else { @@ -517,9 +583,9 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc return } if maxID-watermark >= int64(threshold) { - log.Printf("[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark) if err := s.triggerFullRebuild("outbox_backlog"); err != nil { - log.Printf("[Scheduler] outbox backlog rebuild failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox backlog rebuild failed: %v", err) } } } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f5ba9d71..f7e4fb6b 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -7,16 +7,31 @@ import ( "encoding/json" "errors" "fmt" + "log/slog" + "net/url" "strconv" "strings" + "sync/atomic" + "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "golang.org/x/sync/singleflight" ) var ( - ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") - ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") + ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") + ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") + ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") + ErrDefaultSubGroupInvalid = infraerrors.BadRequest( + "DEFAULT_SUBSCRIPTION_GROUP_INVALID", + "default subscription group must exist and be subscription type", + ) + ErrDefaultSubGroupDuplicate = infraerrors.BadRequest( + "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", + "default subscription group cannot be duplicated", + ) ) type SettingRepository interface { @@ -29,12 +44,40 @@ type SettingRepository interface { Delete(ctx context.Context, key string) error } +// cachedMinVersion 缓存最低 Claude Code 版本号(进程内缓存,60s TTL) +type cachedMinVersion struct { + value string // 空字符串 = 不检查 + expiresAt int64 // unix nano +} + +// minVersionCache 最低版本号进程内缓存 +var minVersionCache atomic.Value // *cachedMinVersion + +// minVersionSF 防止缓存过期时 thundering herd +var minVersionSF singleflight.Group + +// minVersionCacheTTL 缓存有效期 +const minVersionCacheTTL = 60 * time.Second + +// minVersionErrorTTL DB 错误时的短缓存,快速重试 +const minVersionErrorTTL = 5 * time.Second + +// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context +const minVersionDBTimeout = 5 * time.Second + +// DefaultSubscriptionGroupReader validates group references used by default subscriptions. +type DefaultSubscriptionGroupReader interface { + GetByID(ctx context.Context, id int64) (*Group, error) +} + // SettingService 系统设置服务 type SettingService struct { - settingRepo SettingRepository - cfg *config.Config - onUpdate func() // Callback when settings are updated (for cache invalidation) - version string // Application version + settingRepo SettingRepository + defaultSubGroupReader DefaultSubscriptionGroupReader + cfg *config.Config + onUpdate func() // Callback when settings are updated (for cache invalidation) + onS3Update func() // Callback when Sora S3 settings are updated + version string // Application version } // NewSettingService 创建系统设置服务实例 @@ -45,6 +88,11 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti } } +// SetDefaultSubscriptionGroupReader injects an optional group reader for default subscription validation. +func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscriptionGroupReader) { + s.defaultSubGroupReader = reader +} + // GetAllSettings 获取所有系统设置 func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) { settings, err := s.settingRepo.GetAll(ctx) @@ -76,6 +124,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyHideCcsImportButton, SettingKeyPurchaseSubscriptionEnabled, SettingKeyPurchaseSubscriptionURL, + SettingKeySoraClientEnabled, + SettingKeyCustomMenuItems, SettingKeyLinuxDoConnectEnabled, } @@ -114,6 +164,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", + CustomMenuItems: settings[SettingKeyCustomMenuItems], LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -124,6 +176,11 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) { s.onUpdate = callback } +// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。 +func (s *SettingService) SetOnS3UpdateCallback(callback func()) { + s.onS3Update = callback +} + // SetVersion sets the application version for injection into public settings func (s *SettingService) SetVersion(version string) { s.version = version @@ -139,26 +196,28 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any // Return a struct that matches the frontend's expected format return &struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo,omitempty"` - SiteSubtitle string `json:"site_subtitle,omitempty"` - APIBaseURL string `json:"api_base_url,omitempty"` - ContactInfo string `json:"contact_info,omitempty"` - DocURL string `json:"doc_url,omitempty"` - HomeContent string `json:"home_content,omitempty"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - Version string `json:"version,omitempty"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo,omitempty"` + SiteSubtitle string `json:"site_subtitle,omitempty"` + APIBaseURL string `json:"api_base_url,omitempty"` + ContactInfo string `json:"contact_info,omitempty"` + DocURL string `json:"doc_url,omitempty"` + HomeContent string `json:"home_content,omitempty"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` + SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems json.RawMessage `json:"custom_menu_items"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + Version string `json:"version,omitempty"` }{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, @@ -178,13 +237,126 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + SoraClientEnabled: settings.SoraClientEnabled, + CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: s.version, }, nil } +// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON +// array string, returning only items with visibility != "admin". +func filterUserVisibleMenuItems(raw string) json.RawMessage { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return json.RawMessage("[]") + } + var items []struct { + Visibility string `json:"visibility"` + } + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return json.RawMessage("[]") + } + + // Parse full items to preserve all fields + var fullItems []json.RawMessage + if err := json.Unmarshal([]byte(raw), &fullItems); err != nil { + return json.RawMessage("[]") + } + + var filtered []json.RawMessage + for i, item := range items { + if item.Visibility != "admin" { + filtered = append(filtered, fullItems[i]) + } + } + if len(filtered) == 0 { + return json.RawMessage("[]") + } + result, err := json.Marshal(filtered) + if err != nil { + return json.RawMessage("[]") + } + return result +} + +// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url +// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. +func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { + settings, err := s.GetPublicSettings(ctx) + if err != nil { + return nil, err + } + + seen := make(map[string]struct{}) + var origins []string + + addOrigin := func(rawURL string) { + if origin := extractOriginFromURL(rawURL); origin != "" { + if _, ok := seen[origin]; !ok { + seen[origin] = struct{}{} + origins = append(origins, origin) + } + } + } + + // purchase subscription URL + if settings.PurchaseSubscriptionEnabled { + addOrigin(settings.PurchaseSubscriptionURL) + } + + // all custom menu items (including admin-only, since CSP must allow all iframes) + for _, item := range parseCustomMenuItemURLs(settings.CustomMenuItems) { + addOrigin(item) + } + + return origins, nil +} + +// extractOriginFromURL returns the scheme+host origin from rawURL. +// Only http and https schemes are accepted. +func extractOriginFromURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "" + } + u, err := url.Parse(rawURL) + if err != nil || u.Host == "" { + return "" + } + if u.Scheme != "http" && u.Scheme != "https" { + return "" + } + return u.Scheme + "://" + u.Host +} + +// parseCustomMenuItemURLs extracts URLs from a raw JSON array of custom menu items. +func parseCustomMenuItemURLs(raw string) []string { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return nil + } + var items []struct { + URL string `json:"url"` + } + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil + } + urls := make([]string, 0, len(items)) + for _, item := range items { + if item.URL != "" { + urls = append(urls, item.URL) + } + } + return urls +} + // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { + if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { + return err + } + updates := make(map[string]string) // 注册设置 @@ -232,10 +404,17 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton) updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) + updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) + updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) + defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) + if err != nil { + return fmt.Errorf("marshal default subscriptions: %w", err) + } + updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON) // Model fallback configuration updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback) @@ -256,13 +435,63 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds) } - err := s.settingRepo.SetMultiple(ctx, updates) - if err == nil && s.onUpdate != nil { - s.onUpdate() // Invalidate cache after settings update + // Claude Code version check + updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion + + err = s.settingRepo.SetMultiple(ctx, updates) + if err == nil { + // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 + minVersionSF.Forget("min_version") + minVersionCache.Store(&cachedMinVersion{ + value: settings.MinClaudeCodeVersion, + expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(), + }) + if s.onUpdate != nil { + s.onUpdate() // Invalidate cache after settings update + } } return err } +func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error { + if len(items) == 0 { + return nil + } + + checked := make(map[int64]struct{}, len(items)) + for _, item := range items { + if item.GroupID <= 0 { + continue + } + if _, ok := checked[item.GroupID]; ok { + return ErrDefaultSubGroupDuplicate.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + checked[item.GroupID] = struct{}{} + if s.defaultSubGroupReader == nil { + continue + } + + group, err := s.defaultSubGroupReader.GetByID(ctx, item.GroupID) + if err != nil { + if errors.Is(err, ErrGroupNotFound) { + return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + return fmt.Errorf("get default subscription group %d: %w", item.GroupID, err) + } + if !group.IsSubscriptionType() { + return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + } + + return nil +} + // IsRegistrationEnabled 检查是否开放注册 func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) @@ -362,6 +591,15 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { return s.cfg.Default.UserBalance } +// GetDefaultSubscriptions 获取新用户默认订阅配置列表。 +func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting { + value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions) + if err != nil { + return nil + } + return parseDefaultSubscriptions(value) +} + // InitializeDefaultSettings 初始化默认设置 func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 检查是否已有设置 @@ -383,8 +621,11 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeySiteLogo: "", SettingKeyPurchaseSubscriptionEnabled: "false", SettingKeyPurchaseSubscriptionURL: "", + SettingKeySoraClientEnabled: "false", + SettingKeyCustomMenuItems: "[]", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultSubscriptions: "[]", SettingKeySMTPPort: "587", SettingKeySMTPUseTLS: "false", // Model fallback defaults @@ -402,6 +643,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyOpsRealtimeMonitoringEnabled: "true", SettingKeyOpsQueryModeDefault: "auto", SettingKeyOpsMetricsIntervalSeconds: "60", + + // Claude Code version check (default: empty = disabled) + SettingKeyMinClaudeCodeVersion: "", } return s.settingRepo.SetMultiple(ctx, defaults) @@ -436,6 +680,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", + CustomMenuItems: settings[SettingKeyCustomMenuItems], } // 解析整数类型 @@ -457,6 +703,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.DefaultBalance = s.cfg.Default.UserBalance } + result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions]) // 敏感信息直接返回,方便测试连接时使用 result.SMTPPassword = settings[SettingKeySMTPPassword] @@ -526,6 +773,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } } + // Claude Code version check + result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion] + return result } @@ -538,6 +788,31 @@ func isFalseSettingValue(value string) bool { } } +func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + + var items []DefaultSubscriptionSetting + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil + } + + normalized := make([]DefaultSubscriptionSetting, 0, len(items)) + for _, item := range items { + if item.GroupID <= 0 || item.ValidityDays <= 0 { + continue + } + if item.ValidityDays > MaxValidityDays { + item.ValidityDays = MaxValidityDays + } + normalized = append(normalized, item) + } + + return normalized +} + // getStringOrDefault 获取字符串值或默认值 func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { if value, ok := settings[key]; ok && value != "" { @@ -823,6 +1098,53 @@ func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamT return &settings, nil } +// GetMinClaudeCodeVersion 获取最低 Claude Code 版本号要求 +// 使用进程内 atomic.Value 缓存,60 秒 TTL,热路径零锁开销 +// singleflight 防止缓存过期时 thundering herd +// 返回空字符串表示不做版本检查 +func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string { + if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok { + if time.Now().UnixNano() < cached.expiresAt { + return cached.value + } + } + // singleflight: 同一时刻只有一个 goroutine 查询 DB,其余复用结果 + result, err, _ := minVersionSF.Do("min_version", func() (any, error) { + // 二次检查,避免排队的 goroutine 重复查询 + if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok { + if time.Now().UnixNano() < cached.expiresAt { + return cached.value, nil + } + } + // 使用独立 context:断开请求取消链,避免客户端断连导致空值被长期缓存 + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), minVersionDBTimeout) + defer cancel() + value, err := s.settingRepo.GetValue(dbCtx, SettingKeyMinClaudeCodeVersion) + if err != nil { + // fail-open: DB 错误时不阻塞请求,但记录日志并使用短 TTL 快速重试 + slog.Warn("failed to get min claude code version setting, skipping version check", "error", err) + minVersionCache.Store(&cachedMinVersion{ + value: "", + expiresAt: time.Now().Add(minVersionErrorTTL).UnixNano(), + }) + return "", nil + } + minVersionCache.Store(&cachedMinVersion{ + value: value, + expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(), + }) + return value, nil + }) + if err != nil { + return "" + } + ver, ok := result.(string) + if !ok { + return "" + } + return ver +} + // SetStreamTimeoutSettings 设置流超时处理配置 func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error { if settings == nil { @@ -854,3 +1176,607 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data)) } + +type soraS3ProfilesStore struct { + ActiveProfileID string `json:"active_profile_id"` + Items []soraS3ProfileStoreItem `json:"items"` +} + +type soraS3ProfileStoreItem struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置) +func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { + profiles, err := s.ListSoraS3Profiles(ctx) + if err != nil { + return nil, err + } + + activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) + if activeProfile == nil { + return &SoraS3Settings{}, nil + } + + return &SoraS3Settings{ + Enabled: activeProfile.Enabled, + Endpoint: activeProfile.Endpoint, + Region: activeProfile.Region, + Bucket: activeProfile.Bucket, + AccessKeyID: activeProfile.AccessKeyID, + SecretAccessKey: activeProfile.SecretAccessKey, + SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured, + Prefix: activeProfile.Prefix, + ForcePathStyle: activeProfile.ForcePathStyle, + CDNURL: activeProfile.CDNURL, + DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes, + }, nil +} + +// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置) +func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return err + } + + now := time.Now().UTC().Format(time.RFC3339) + activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID) + if activeIndex < 0 { + activeID := "default" + if hasSoraS3ProfileID(store.Items, activeID) { + activeID = fmt.Sprintf("default-%d", time.Now().Unix()) + } + store.Items = append(store.Items, soraS3ProfileStoreItem{ + ProfileID: activeID, + Name: "Default", + UpdatedAt: now, + }) + store.ActiveProfileID = activeID + activeIndex = len(store.Items) - 1 + } + + active := store.Items[activeIndex] + active.Enabled = settings.Enabled + active.Endpoint = strings.TrimSpace(settings.Endpoint) + active.Region = strings.TrimSpace(settings.Region) + active.Bucket = strings.TrimSpace(settings.Bucket) + active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID) + active.Prefix = strings.TrimSpace(settings.Prefix) + active.ForcePathStyle = settings.ForcePathStyle + active.CDNURL = strings.TrimSpace(settings.CDNURL) + active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0) + if settings.SecretAccessKey != "" { + active.SecretAccessKey = settings.SecretAccessKey + } + active.UpdatedAt = now + store.Items[activeIndex] = active + + return s.persistSoraS3ProfilesStore(ctx, store) +} + +// ListSoraS3Profiles 获取 Sora S3 多配置列表 +func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) { + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + return convertSoraS3ProfilesStore(store), nil +} + +// CreateSoraS3Profile 创建 Sora S3 配置 +func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) { + if profile == nil { + return nil, fmt.Errorf("profile cannot be nil") + } + + profileID := strings.TrimSpace(profile.ProfileID) + if profileID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + name := strings.TrimSpace(profile.Name) + if name == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + if hasSoraS3ProfileID(store.Items, profileID) { + return nil, ErrSoraS3ProfileExists + } + + now := time.Now().UTC().Format(time.RFC3339) + store.Items = append(store.Items, soraS3ProfileStoreItem{ + ProfileID: profileID, + Name: name, + Enabled: profile.Enabled, + Endpoint: strings.TrimSpace(profile.Endpoint), + Region: strings.TrimSpace(profile.Region), + Bucket: strings.TrimSpace(profile.Bucket), + AccessKeyID: strings.TrimSpace(profile.AccessKeyID), + SecretAccessKey: profile.SecretAccessKey, + Prefix: strings.TrimSpace(profile.Prefix), + ForcePathStyle: profile.ForcePathStyle, + CDNURL: strings.TrimSpace(profile.CDNURL), + DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }) + + if setActive || store.ActiveProfileID == "" { + store.ActiveProfileID = profileID + } + + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + created := findSoraS3ProfileByID(profiles.Items, profileID) + if created == nil { + return nil, ErrSoraS3ProfileNotFound + } + return created, nil +} + +// UpdateSoraS3Profile 更新 Sora S3 配置 +func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) { + if profile == nil { + return nil, fmt.Errorf("profile cannot be nil") + } + + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + + targetIndex := findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return nil, ErrSoraS3ProfileNotFound + } + + target := store.Items[targetIndex] + name := strings.TrimSpace(profile.Name) + if name == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") + } + target.Name = name + target.Enabled = profile.Enabled + target.Endpoint = strings.TrimSpace(profile.Endpoint) + target.Region = strings.TrimSpace(profile.Region) + target.Bucket = strings.TrimSpace(profile.Bucket) + target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID) + target.Prefix = strings.TrimSpace(profile.Prefix) + target.ForcePathStyle = profile.ForcePathStyle + target.CDNURL = strings.TrimSpace(profile.CDNURL) + target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0) + if profile.SecretAccessKey != "" { + target.SecretAccessKey = profile.SecretAccessKey + } + target.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + store.Items[targetIndex] = target + + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + updated := findSoraS3ProfileByID(profiles.Items, targetID) + if updated == nil { + return nil, ErrSoraS3ProfileNotFound + } + return updated, nil +} + +// DeleteSoraS3Profile 删除 Sora S3 配置 +func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error { + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return err + } + + targetIndex := findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return ErrSoraS3ProfileNotFound + } + + store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...) + if store.ActiveProfileID == targetID { + store.ActiveProfileID = "" + if len(store.Items) > 0 { + store.ActiveProfileID = store.Items[0].ProfileID + } + } + + return s.persistSoraS3ProfilesStore(ctx, store) +} + +// SetActiveSoraS3Profile 设置激活的 Sora S3 配置 +func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) { + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + + targetIndex := findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return nil, ErrSoraS3ProfileNotFound + } + + store.ActiveProfileID = targetID + store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339) + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) + if active == nil { + return nil, ErrSoraS3ProfileNotFound + } + return active, nil +} + +func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) { + raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles) + if err == nil { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return &soraS3ProfilesStore{}, nil + } + var store soraS3ProfilesStore + if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil { + legacy, legacyErr := s.getLegacySoraS3Settings(ctx) + if legacyErr != nil { + return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr) + } + if isEmptyLegacySoraS3Settings(legacy) { + return &soraS3ProfilesStore{}, nil + } + now := time.Now().UTC().Format(time.RFC3339) + return &soraS3ProfilesStore{ + ActiveProfileID: "default", + Items: []soraS3ProfileStoreItem{ + { + ProfileID: "default", + Name: "Default", + Enabled: legacy.Enabled, + Endpoint: strings.TrimSpace(legacy.Endpoint), + Region: strings.TrimSpace(legacy.Region), + Bucket: strings.TrimSpace(legacy.Bucket), + AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), + SecretAccessKey: legacy.SecretAccessKey, + Prefix: strings.TrimSpace(legacy.Prefix), + ForcePathStyle: legacy.ForcePathStyle, + CDNURL: strings.TrimSpace(legacy.CDNURL), + DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }, + }, + }, nil + } + normalized := normalizeSoraS3ProfilesStore(store) + return &normalized, nil + } + + if !errors.Is(err, ErrSettingNotFound) { + return nil, fmt.Errorf("get sora s3 profiles: %w", err) + } + + legacy, legacyErr := s.getLegacySoraS3Settings(ctx) + if legacyErr != nil { + return nil, legacyErr + } + if isEmptyLegacySoraS3Settings(legacy) { + return &soraS3ProfilesStore{}, nil + } + + now := time.Now().UTC().Format(time.RFC3339) + return &soraS3ProfilesStore{ + ActiveProfileID: "default", + Items: []soraS3ProfileStoreItem{ + { + ProfileID: "default", + Name: "Default", + Enabled: legacy.Enabled, + Endpoint: strings.TrimSpace(legacy.Endpoint), + Region: strings.TrimSpace(legacy.Region), + Bucket: strings.TrimSpace(legacy.Bucket), + AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), + SecretAccessKey: legacy.SecretAccessKey, + Prefix: strings.TrimSpace(legacy.Prefix), + ForcePathStyle: legacy.ForcePathStyle, + CDNURL: strings.TrimSpace(legacy.CDNURL), + DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }, + }, + }, nil +} + +func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error { + if store == nil { + return fmt.Errorf("sora s3 profiles store cannot be nil") + } + + normalized := normalizeSoraS3ProfilesStore(*store) + data, err := json.Marshal(normalized) + if err != nil { + return fmt.Errorf("marshal sora s3 profiles: %w", err) + } + + updates := map[string]string{ + SettingKeySoraS3Profiles: string(data), + } + + active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID) + if active == nil { + updates[SettingKeySoraS3Enabled] = "false" + updates[SettingKeySoraS3Endpoint] = "" + updates[SettingKeySoraS3Region] = "" + updates[SettingKeySoraS3Bucket] = "" + updates[SettingKeySoraS3AccessKeyID] = "" + updates[SettingKeySoraS3Prefix] = "" + updates[SettingKeySoraS3ForcePathStyle] = "false" + updates[SettingKeySoraS3CDNURL] = "" + updates[SettingKeySoraDefaultStorageQuotaBytes] = "0" + updates[SettingKeySoraS3SecretAccessKey] = "" + } else { + updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled) + updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint) + updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region) + updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket) + updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID) + updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix) + updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle) + updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL) + updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10) + updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey + } + + if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { + return err + } + + if s.onUpdate != nil { + s.onUpdate() + } + if s.onS3Update != nil { + s.onS3Update() + } + return nil +} + +func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { + keys := []string{ + SettingKeySoraS3Enabled, + SettingKeySoraS3Endpoint, + SettingKeySoraS3Region, + SettingKeySoraS3Bucket, + SettingKeySoraS3AccessKeyID, + SettingKeySoraS3SecretAccessKey, + SettingKeySoraS3Prefix, + SettingKeySoraS3ForcePathStyle, + SettingKeySoraS3CDNURL, + SettingKeySoraDefaultStorageQuotaBytes, + } + + values, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return nil, fmt.Errorf("get legacy sora s3 settings: %w", err) + } + + result := &SoraS3Settings{ + Enabled: values[SettingKeySoraS3Enabled] == "true", + Endpoint: values[SettingKeySoraS3Endpoint], + Region: values[SettingKeySoraS3Region], + Bucket: values[SettingKeySoraS3Bucket], + AccessKeyID: values[SettingKeySoraS3AccessKeyID], + SecretAccessKey: values[SettingKeySoraS3SecretAccessKey], + SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "", + Prefix: values[SettingKeySoraS3Prefix], + ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true", + CDNURL: values[SettingKeySoraS3CDNURL], + } + if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil { + result.DefaultStorageQuotaBytes = v + } + return result, nil +} + +func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore { + seen := make(map[string]struct{}, len(store.Items)) + normalized := soraS3ProfilesStore{ + ActiveProfileID: strings.TrimSpace(store.ActiveProfileID), + Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)), + } + now := time.Now().UTC().Format(time.RFC3339) + + for idx := range store.Items { + item := store.Items[idx] + item.ProfileID = strings.TrimSpace(item.ProfileID) + if item.ProfileID == "" { + item.ProfileID = fmt.Sprintf("profile-%d", idx+1) + } + if _, exists := seen[item.ProfileID]; exists { + continue + } + seen[item.ProfileID] = struct{}{} + + item.Name = strings.TrimSpace(item.Name) + if item.Name == "" { + item.Name = item.ProfileID + } + item.Endpoint = strings.TrimSpace(item.Endpoint) + item.Region = strings.TrimSpace(item.Region) + item.Bucket = strings.TrimSpace(item.Bucket) + item.AccessKeyID = strings.TrimSpace(item.AccessKeyID) + item.Prefix = strings.TrimSpace(item.Prefix) + item.CDNURL = strings.TrimSpace(item.CDNURL) + item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0) + item.UpdatedAt = strings.TrimSpace(item.UpdatedAt) + if item.UpdatedAt == "" { + item.UpdatedAt = now + } + normalized.Items = append(normalized.Items, item) + } + + if len(normalized.Items) == 0 { + normalized.ActiveProfileID = "" + return normalized + } + + if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 { + return normalized + } + + normalized.ActiveProfileID = normalized.Items[0].ProfileID + return normalized +} + +func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList { + if store == nil { + return &SoraS3ProfileList{} + } + items := make([]SoraS3Profile, 0, len(store.Items)) + for idx := range store.Items { + item := store.Items[idx] + items = append(items, SoraS3Profile{ + ProfileID: item.ProfileID, + Name: item.Name, + IsActive: item.ProfileID == store.ActiveProfileID, + Enabled: item.Enabled, + Endpoint: item.Endpoint, + Region: item.Region, + Bucket: item.Bucket, + AccessKeyID: item.AccessKeyID, + SecretAccessKey: item.SecretAccessKey, + SecretAccessKeyConfigured: item.SecretAccessKey != "", + Prefix: item.Prefix, + ForcePathStyle: item.ForcePathStyle, + CDNURL: item.CDNURL, + DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes, + UpdatedAt: item.UpdatedAt, + }) + } + return &SoraS3ProfileList{ + ActiveProfileID: store.ActiveProfileID, + Items: items, + } +} + +func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == activeProfileID { + return &items[idx] + } + } + if len(items) == 0 { + return nil + } + return &items[0] +} + +func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == profileID { + return &items[idx] + } + } + return nil +} + +func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem { + for idx := range items { + if items[idx].ProfileID == activeProfileID { + return &items[idx] + } + } + if len(items) == 0 { + return nil + } + return &items[0] +} + +func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int { + for idx := range items { + if items[idx].ProfileID == profileID { + return idx + } + } + return -1 +} + +func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool { + return findSoraS3ProfileIndex(items, profileID) >= 0 +} + +func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool { + if settings == nil { + return true + } + if settings.Enabled { + return false + } + if strings.TrimSpace(settings.Endpoint) != "" { + return false + } + if strings.TrimSpace(settings.Region) != "" { + return false + } + if strings.TrimSpace(settings.Bucket) != "" { + return false + } + if strings.TrimSpace(settings.AccessKeyID) != "" { + return false + } + if settings.SecretAccessKey != "" { + return false + } + if strings.TrimSpace(settings.Prefix) != "" { + return false + } + if strings.TrimSpace(settings.CDNURL) != "" { + return false + } + return settings.DefaultStorageQuotaBytes == 0 +} + +func maxInt64(value int64, min int64) int64 { + if value < min { + return min + } + return value +} diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go new file mode 100644 index 00000000..ec64511f --- /dev/null +++ b/backend/internal/service/setting_service_update_test.go @@ -0,0 +1,182 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type settingUpdateRepoStub struct { + updates map[string]string +} + +func (s *settingUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *settingUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingUpdateRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *settingUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.updates = make(map[string]string, len(settings)) + for k, v := range settings { + s.updates[k] = v + } + return nil +} + +func (s *settingUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingUpdateRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +type defaultSubGroupReaderStub struct { + byID map[int64]*Group + errBy map[int64]error + calls []int64 +} + +func (s *defaultSubGroupReaderStub) GetByID(ctx context.Context, id int64) (*Group, error) { + s.calls = append(s.calls, id) + if err, ok := s.errBy[id]; ok { + return nil, err + } + if g, ok := s.byID[id]; ok { + return g, nil + } + return nil, ErrGroupNotFound +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_ValidGroup(t *testing.T) { + repo := &settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + byID: map[int64]*Group{ + 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription}, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + }, + }) + require.NoError(t, err) + require.Equal(t, []int64{11}, groupReader.calls) + + raw, ok := repo.updates[SettingKeyDefaultSubscriptions] + require.True(t, ok) + + var got []DefaultSubscriptionSetting + require.NoError(t, json.Unmarshal([]byte(raw), &got)) + require.Equal(t, []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + }, got) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNonSubscriptionGroup(t *testing.T) { + repo := &settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + byID: map[int64]*Group{ + 12: {ID: 12, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 12, ValidityDays: 7}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err)) + require.Nil(t, repo.updates) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNotFoundGroup(t *testing.T) { + repo := &settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + errBy: map[int64]error{ + 13: ErrGroupNotFound, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 13, ValidityDays: 7}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err)) + require.Equal(t, "13", infraerrors.FromError(err).Metadata["group_id"]) + require.Nil(t, repo.updates) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroup(t *testing.T) { + repo := &settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + byID: map[int64]*Group{ + 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription}, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + {GroupID: 11, ValidityDays: 60}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err)) + require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"]) + require.Nil(t, repo.updates) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroupWithoutGroupReader(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + {GroupID: 11, ValidityDays: 60}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err)) + require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"]) + require.Nil(t, repo.updates) +} + +func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) { + got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`) + require.Equal(t, []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + {GroupID: 11, ValidityDays: 60}, + {GroupID: 12, ValidityDays: MaxValidityDays}, + }, got) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 0c7bab67..9f0de600 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -39,9 +39,12 @@ type SystemSettings struct { HideCcsImportButton bool PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string + SoraClientEnabled bool + CustomMenuItems string // JSON array of custom menu items - DefaultConcurrency int - DefaultBalance float64 + DefaultConcurrency int + DefaultBalance float64 + DefaultSubscriptions []DefaultSubscriptionSetting // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -59,6 +62,14 @@ type SystemSettings struct { OpsRealtimeMonitoringEnabled bool OpsQueryModeDefault string OpsMetricsIntervalSeconds int + + // Claude Code version check + MinClaudeCodeVersion string +} + +type DefaultSubscriptionSetting struct { + GroupID int64 `json:"group_id"` + ValidityDays int `json:"validity_days"` } type PublicSettings struct { @@ -81,11 +92,53 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string + SoraClientEnabled bool + CustomMenuItems string // JSON array of custom menu items LinuxDoOAuthEnabled bool Version string } +// SoraS3Settings Sora S3 存储配置 +type SoraS3Settings struct { + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端 + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// SoraS3Profile Sora S3 多配置项(服务内部模型) +type SoraS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端 + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// SoraS3ProfileList Sora S3 多配置列表 +type SoraS3ProfileList struct { + ActiveProfileID string `json:"active_profile_id"` + Items []SoraS3Profile `json:"items"` +} + // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) type StreamTimeoutSettings struct { // Enabled 是否启用流超时处理 diff --git a/backend/internal/service/sora_account_service.go b/backend/internal/service/sora_account_service.go new file mode 100644 index 00000000..eccc1acf --- /dev/null +++ b/backend/internal/service/sora_account_service.go @@ -0,0 +1,40 @@ +package service + +import "context" + +// SoraAccountRepository Sora 账号扩展表仓储接口 +// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。 +// +// 设计说明: +// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本 +// - Sora gateway 优先读取此表的字段以获得更好的查询性能 +// - 主表 accounts 通过 credentials JSON 字段也存储相同信息 +// - Token 刷新时需要同时更新两个表以保持数据一致性 +type SoraAccountRepository interface { + // Upsert 创建或更新 Sora 账号扩展信息 + // accountID: 关联的 accounts.id + // updates: 要更新的字段,支持 access_token、refresh_token、session_token + // + // 如果记录不存在则创建,存在则更新。 + // 用于: + // 1. 创建 Sora 账号时初始化扩展表 + // 2. Token 刷新时同步更新扩展表 + Upsert(ctx context.Context, accountID int64, updates map[string]any) error + + // GetByAccountID 根据账号 ID 获取 Sora 扩展信息 + // 返回 nil, nil 表示记录不存在(非错误) + GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error) + + // Delete 删除 Sora 账号扩展信息 + // 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理 + Delete(ctx context.Context, accountID int64) error +} + +// SoraAccount Sora 账号扩展信息 +// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本 +type SoraAccount struct { + AccountID int64 // 关联的 accounts.id + AccessToken string // OAuth access_token + RefreshToken string // OAuth refresh_token + SessionToken string // Session token(可选,用于 ST→AT 兜底) +} diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go new file mode 100644 index 00000000..0a914d2d --- /dev/null +++ b/backend/internal/service/sora_client.go @@ -0,0 +1,117 @@ +package service + +import ( + "context" + "fmt" + "net/http" +) + +// SoraClient 定义直连 Sora 的任务操作接口。 +type SoraClient interface { + Enabled() bool + UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) + CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) + CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) + CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) + UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) + GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) + DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) + UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) + FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) + SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error + DeleteCharacter(ctx context.Context, account *Account, characterID string) error + PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) + DeletePost(ctx context.Context, account *Account, postID string) error + GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) + EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) + GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) + GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) +} + +// SoraImageRequest 图片生成请求参数 +type SoraImageRequest struct { + Prompt string + Width int + Height int + MediaID string +} + +// SoraVideoRequest 视频生成请求参数 +type SoraVideoRequest struct { + Prompt string + Orientation string + Frames int + Model string + Size string + VideoCount int + MediaID string + RemixTargetID string + CameoIDs []string +} + +// SoraStoryboardRequest 分镜视频生成请求参数 +type SoraStoryboardRequest struct { + Prompt string + Orientation string + Frames int + Model string + Size string + MediaID string +} + +// SoraImageTaskStatus 图片任务状态 +type SoraImageTaskStatus struct { + ID string + Status string + ProgressPct float64 + URLs []string + ErrorMsg string +} + +// SoraVideoTaskStatus 视频任务状态 +type SoraVideoTaskStatus struct { + ID string + Status string + ProgressPct int + URLs []string + GenerationID string + ErrorMsg string +} + +// SoraCameoStatus 角色处理中间态 +type SoraCameoStatus struct { + Status string + StatusMessage string + DisplayNameHint string + UsernameHint string + ProfileAssetURL string + InstructionSetHint any + InstructionSet any +} + +// SoraCharacterFinalizeRequest 角色定稿请求参数 +type SoraCharacterFinalizeRequest struct { + CameoID string + Username string + DisplayName string + ProfileAssetPointer string + InstructionSet any +} + +// SoraUpstreamError 上游错误 +type SoraUpstreamError struct { + StatusCode int + Message string + Headers http.Header + Body []byte +} + +func (e *SoraUpstreamError) Error() string { + if e == nil { + return "sora upstream error" + } + if e.Message != "" { + return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message) + } + return fmt.Sprintf("sora upstream error: %d", e.StatusCode) +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go new file mode 100644 index 00000000..ab6871bb --- /dev/null +++ b/backend/internal/service/sora_gateway_service.go @@ -0,0 +1,1553 @@ +package service + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math" + "math/rand" + "mime" + "net" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" +) + +const soraImageInputMaxBytes = 20 << 20 +const soraImageInputMaxRedirects = 3 +const soraImageInputTimeout = 20 * time.Second +const soraVideoInputMaxBytes = 200 << 20 +const soraVideoInputMaxRedirects = 3 +const soraVideoInputTimeout = 60 * time.Second + +var soraImageSizeMap = map[string]string{ + "gpt-image": "360", + "gpt-image-landscape": "540", + "gpt-image-portrait": "540", +} + +var soraBlockedHostnames = map[string]struct{}{ + "localhost": {}, + "localhost.localdomain": {}, + "metadata.google.internal": {}, + "metadata.google.internal.": {}, +} + +var soraBlockedCIDRs = mustParseCIDRs([]string{ + "0.0.0.0/8", + "10.0.0.0/8", + "100.64.0.0/10", + "127.0.0.0/8", + "169.254.0.0/16", + "172.16.0.0/12", + "192.168.0.0/16", + "224.0.0.0/4", + "240.0.0.0/4", + "::/128", + "::1/128", + "fc00::/7", + "fe80::/10", +}) + +// SoraGatewayService handles forwarding requests to Sora upstream. +type SoraGatewayService struct { + soraClient SoraClient + rateLimitService *RateLimitService + httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传 + cfg *config.Config +} + +type soraWatermarkOptions struct { + Enabled bool + ParseMethod string + ParseURL string + ParseToken string + FallbackOnFailure bool + DeletePost bool +} + +type soraCharacterOptions struct { + SetPublic bool + DeleteAfterGenerate bool +} + +type soraCharacterFlowResult struct { + CameoID string + CharacterID string + Username string + DisplayName string +} + +var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`) +var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`) +var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`) +var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`) + +type soraPreflightChecker interface { + PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error +} + +func NewSoraGatewayService( + soraClient SoraClient, + rateLimitService *RateLimitService, + httpUpstream HTTPUpstream, + cfg *config.Config, +) *SoraGatewayService { + return &SoraGatewayService{ + soraClient: soraClient, + rateLimitService: rateLimitService, + httpUpstream: httpUpstream, + cfg: cfg, + } +} + +func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { + startTime := time.Now() + + // apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient + if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" { + if s.httpUpstream == nil { + s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream) + return nil, errors.New("httpUpstream not configured for sora apikey forwarding") + } + return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime) + } + + if s.soraClient == nil || !s.soraClient.Enabled() { + if c != nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Sora 上游未配置", + }, + }) + } + return nil, errors.New("sora upstream not configured") + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream) + return nil, fmt.Errorf("parse request: %w", err) + } + reqModel, _ := reqBody["model"].(string) + reqStream, _ := reqBody["stream"].(bool) + if strings.TrimSpace(reqModel) == "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) + return nil, errors.New("model is required") + } + + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != "" && mappedModel != reqModel { + reqModel = mappedModel + } + + modelCfg, ok := GetSoraModelConfig(reqModel) + if !ok { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) + return nil, fmt.Errorf("unsupported model: %s", reqModel) + } + prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) + prompt = strings.TrimSpace(prompt) + imageInput = strings.TrimSpace(imageInput) + videoInput = strings.TrimSpace(videoInput) + remixTargetID = strings.TrimSpace(remixTargetID) + + if videoInput != "" && modelCfg.Type != "video" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream) + return nil, errors.New("video input only supports video models") + } + if videoInput != "" && imageInput != "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream) + return nil, errors.New("image input and video input cannot be used together") + } + characterOnly := videoInput != "" && prompt == "" + if modelCfg.Type == "prompt_enhance" && prompt == "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) + return nil, errors.New("prompt is required") + } + if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) + return nil, errors.New("prompt is required") + } + + reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) + if cancel != nil { + defer cancel() + } + if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly { + if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + } + + if modelCfg.Type == "prompt_enhance" { + enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS) + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + content := strings.TrimSpace(enhancedPrompt) + if content == "" { + content = prompt + } + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) + } + return &ForwardResult{ + RequestID: "", + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", + }, nil + } + + characterOpts := parseSoraCharacterOptions(reqBody) + watermarkOpts := parseSoraWatermarkOptions(reqBody) + var characterResult *soraCharacterFlowResult + if videoInput != "" { + videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput) + if videoErr != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream) + return nil, videoErr + } + characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts) + if videoErr != nil { + return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream) + } + if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly { + characterID := strings.TrimSpace(characterResult.CharacterID) + defer func() { + cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second) + defer cancelCleanup() + if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil { + log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err) + } + }() + } + if characterOnly { + content := "角色创建成功" + if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { + content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username)) + } + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + resp := buildSoraNonStreamResponse(content, reqModel) + if characterResult != nil { + resp["character_id"] = characterResult.CharacterID + resp["cameo_id"] = characterResult.CameoID + resp["character_username"] = characterResult.Username + resp["character_display_name"] = characterResult.DisplayName + } + c.JSON(http.StatusOK, resp) + } + return &ForwardResult{ + RequestID: "", + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", + }, nil + } + if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { + prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt) + } + } + + var imageData []byte + imageFilename := "" + if imageInput != "" { + decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput) + if err != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream) + return nil, err + } + imageData = decoded + imageFilename = filename + } + + mediaID := "" + if len(imageData) > 0 { + uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename) + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + mediaID = uploadID + } + + taskID := "" + var err error + videoCount := parseSoraVideoCount(reqBody) + switch modelCfg.Type { + case "image": + taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{ + Prompt: prompt, + Width: modelCfg.Width, + Height: modelCfg.Height, + MediaID: mediaID, + }) + case "video": + if remixTargetID == "" && isSoraStoryboardPrompt(prompt) { + taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{ + Prompt: formatSoraStoryboardPrompt(prompt), + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + MediaID: mediaID, + }) + } else { + taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ + Prompt: prompt, + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + VideoCount: videoCount, + MediaID: mediaID, + RemixTargetID: remixTargetID, + CameoIDs: extractSoraCameoIDs(reqBody), + }) + } + default: + err = fmt.Errorf("unsupported model type: %s", modelCfg.Type) + } + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + + if clientStream && c != nil { + s.prepareSoraStream(c, taskID) + } + + var mediaURLs []string + videoGenerationID := "" + mediaType := modelCfg.Type + imageCount := 0 + imageSize := "" + switch modelCfg.Type { + case "image": + urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream) + if pollErr != nil { + return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) + } + mediaURLs = urls + imageCount = len(urls) + imageSize = soraImageSizeFromModel(reqModel) + case "video": + videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream) + if pollErr != nil { + return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) + } + if videoStatus != nil { + mediaURLs = videoStatus.URLs + videoGenerationID = strings.TrimSpace(videoStatus.GenerationID) + } + default: + mediaType = "prompt" + } + + watermarkPostID := "" + if modelCfg.Type == "video" && watermarkOpts.Enabled { + watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts) + if watermarkErr != nil { + if !watermarkOpts.FallbackOnFailure { + return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream) + } + log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr) + } else if strings.TrimSpace(watermarkURL) != "" { + mediaURLs = []string{strings.TrimSpace(watermarkURL)} + watermarkPostID = strings.TrimSpace(postID) + } + } + + // 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。 + // 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。 + finalURLs := s.normalizeSoraMediaURLs(mediaURLs) + if watermarkPostID != "" && watermarkOpts.DeletePost { + if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil { + log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr) + } + } + + content := buildSoraContent(mediaType, finalURLs) + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + response := buildSoraNonStreamResponse(content, reqModel) + if len(finalURLs) > 0 { + response["media_url"] = finalURLs[0] + if len(finalURLs) > 1 { + response["media_urls"] = finalURLs + } + } + c.JSON(http.StatusOK, response) + } + + return &ForwardResult{ + RequestID: taskID, + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: mediaType, + MediaURL: firstMediaURL(finalURLs), + ImageCount: imageCount, + ImageSize: imageSize, + }, nil +} + +func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if s == nil || s.cfg == nil { + return ctx, nil + } + timeoutSeconds := s.cfg.Gateway.SoraRequestTimeoutSeconds + if stream { + timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds + } + if timeoutSeconds <= 0 { + return ctx, nil + } + return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) +} + +func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions { + opts := soraWatermarkOptions{ + Enabled: parseBoolWithDefault(body, "watermark_free", false), + ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))), + ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")), + ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")), + FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true), + DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false), + } + if opts.ParseMethod == "" { + opts.ParseMethod = "third_party" + } + return opts +} + +func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions { + return soraCharacterOptions{ + SetPublic: parseBoolWithDefault(body, "character_set_public", true), + DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true), + } +} + +func parseSoraVideoCount(body map[string]any) int { + if body == nil { + return 1 + } + keys := []string{"video_count", "videos", "n_variants"} + for _, key := range keys { + count := parseIntWithDefault(body, key, 0) + if count > 0 { + return clampInt(count, 1, 3) + } + } + return 1 +} + +func parseBoolWithDefault(body map[string]any, key string, def bool) bool { + if body == nil { + return def + } + val, ok := body[key] + if !ok { + return def + } + switch typed := val.(type) { + case bool: + return typed + case int: + return typed != 0 + case int32: + return typed != 0 + case int64: + return typed != 0 + case float64: + return typed != 0 + case string: + typed = strings.ToLower(strings.TrimSpace(typed)) + if typed == "true" || typed == "1" || typed == "yes" { + return true + } + if typed == "false" || typed == "0" || typed == "no" { + return false + } + } + return def +} + +func parseStringWithDefault(body map[string]any, key, def string) string { + if body == nil { + return def + } + val, ok := body[key] + if !ok { + return def + } + if str, ok := val.(string); ok { + return str + } + return def +} + +func parseIntWithDefault(body map[string]any, key string, def int) int { + if body == nil { + return def + } + val, ok := body[key] + if !ok { + return def + } + switch typed := val.(type) { + case int: + return typed + case int32: + return int(typed) + case int64: + return int(typed) + case float64: + return int(typed) + case string: + parsed, err := strconv.Atoi(strings.TrimSpace(typed)) + if err == nil { + return parsed + } + } + return def +} + +func clampInt(v, minVal, maxVal int) int { + if v < minVal { + return minVal + } + if v > maxVal { + return maxVal + } + return v +} + +func extractSoraCameoIDs(body map[string]any) []string { + if body == nil { + return nil + } + raw, ok := body["cameo_ids"] + if !ok { + return nil + } + switch typed := raw.(type) { + case []string: + out := make([]string, 0, len(typed)) + for _, item := range typed { + item = strings.TrimSpace(item) + if item != "" { + out = append(out, item) + } + } + return out + case []any: + out := make([]string, 0, len(typed)) + for _, item := range typed { + str, ok := item.(string) + if !ok { + continue + } + str = strings.TrimSpace(str) + if str != "" { + out = append(out, str) + } + } + return out + default: + return nil + } +} + +func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) { + cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData) + if err != nil { + return nil, err + } + + cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID) + if err != nil { + return nil, err + } + username := processSoraCharacterUsername(cameoStatus.UsernameHint) + displayName := strings.TrimSpace(cameoStatus.DisplayNameHint) + if displayName == "" { + displayName = "Character" + } + profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL) + if profileAssetURL == "" { + return nil, errors.New("profile asset url not found in cameo status") + } + + avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL) + if err != nil { + return nil, err + } + assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData) + if err != nil { + return nil, err + } + instructionSet := cameoStatus.InstructionSetHint + if instructionSet == nil { + instructionSet = cameoStatus.InstructionSet + } + + characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{ + CameoID: strings.TrimSpace(cameoID), + Username: username, + DisplayName: displayName, + ProfileAssetPointer: assetPointer, + InstructionSet: instructionSet, + }) + if err != nil { + return nil, err + } + + if opts.SetPublic { + if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil { + return nil, err + } + } + + return &soraCharacterFlowResult{ + CameoID: strings.TrimSpace(cameoID), + CharacterID: strings.TrimSpace(characterID), + Username: strings.TrimSpace(username), + DisplayName: displayName, + }, nil +} + +func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + timeout := 10 * time.Minute + interval := 5 * time.Second + maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds())) + if maxAttempts < 1 { + maxAttempts = 1 + } + + var lastErr error + consecutiveErrors := 0 + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID) + if err != nil { + lastErr = err + consecutiveErrors++ + if consecutiveErrors >= 3 { + break + } + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + continue + } + consecutiveErrors = 0 + if status == nil { + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + continue + } + currentStatus := strings.ToLower(strings.TrimSpace(status.Status)) + statusMessage := strings.TrimSpace(status.StatusMessage) + if currentStatus == "failed" { + if statusMessage == "" { + statusMessage = "character creation failed" + } + return nil, errors.New(statusMessage) + } + if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" { + return status, nil + } + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + } + if lastErr != nil { + return nil, fmt.Errorf("poll cameo status failed: %w", lastErr) + } + return nil, errors.New("cameo processing timeout") +} + +func processSoraCharacterUsername(usernameHint string) string { + usernameHint = strings.TrimSpace(usernameHint) + if usernameHint == "" { + usernameHint = "character" + } + if strings.Contains(usernameHint, ".") { + parts := strings.Split(usernameHint, ".") + usernameHint = strings.TrimSpace(parts[len(parts)-1]) + } + if usernameHint == "" { + usernameHint = "character" + } + return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100) +} + +func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) { + generationID = strings.TrimSpace(generationID) + if generationID == "" { + return "", "", errors.New("generation id is required for watermark-free mode") + } + postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID) + if err != nil { + return "", "", err + } + postID = strings.TrimSpace(postID) + if postID == "" { + return "", "", errors.New("watermark-free publish returned empty post id") + } + + switch opts.ParseMethod { + case "custom": + urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID) + if parseErr != nil { + return "", postID, parseErr + } + return strings.TrimSpace(urlVal), postID, nil + case "", "third_party": + return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil + default: + return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod) + } +} + +func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 402, 403, 404, 429, 529: + return true + default: + return statusCode >= 500 + } +} + +func buildSoraNonStreamResponse(content, model string) map[string]any { + return map[string]any{ + "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": content, + }, + "finish_reason": "stop", + }, + }, + } +} + +func soraImageSizeFromModel(model string) string { + modelLower := strings.ToLower(model) + if size, ok := soraImageSizeMap[modelLower]; ok { + return size + } + if strings.Contains(modelLower, "landscape") || strings.Contains(modelLower, "portrait") { + return "540" + } + return "360" +} + +func soraProErrorMessage(model, upstreamMsg string) string { + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "sora2pro-hd") { + return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号" + } + if strings.Contains(modelLower, "sora2pro") { + return "当前账号无法使用 Sora Pro 模型,请更换模型或账号" + } + return "" +} + +func firstMediaURL(urls []string) string { + if len(urls) == 0 { + return "" + } + return urls[0] +} + +func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string { + if path == "" { + return path + } + prefix := "/sora/media" + values := url.Values{} + if rawQuery != "" { + if parsed, err := url.ParseQuery(rawQuery); err == nil { + values = parsed + } + } + + signKey := "" + ttlSeconds := 0 + if s != nil && s.cfg != nil { + signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey) + ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds + } + values.Del("sig") + values.Del("expires") + signingQuery := values.Encode() + if signKey != "" && ttlSeconds > 0 { + expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix() + signature := SignSoraMediaURL(path, signingQuery, expires, signKey) + if signature != "" { + values.Set("expires", strconv.FormatInt(expires, 10)) + values.Set("sig", signature) + prefix = "/sora/media-signed" + } + } + + encoded := values.Encode() + if encoded == "" { + return prefix + path + } + return prefix + path + "?" + encoded +} + +func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) { + if c == nil { + return + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if strings.TrimSpace(requestID) != "" { + c.Header("x-request-id", requestID) + } +} + +func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) { + if c == nil { + return nil, nil + } + writer := c.Writer + flusher, _ := writer.(http.Flusher) + + chunk := map[string]any{ + "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{ + "content": content, + }, + }, + }, + } + encoded, _ := jsonMarshalRaw(chunk) + if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil { + return nil, err + } + if flusher != nil { + flusher.Flush() + } + ms := int(time.Since(startTime).Milliseconds()) + finalChunk := map[string]any{ + "id": chunk["id"], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{}, + "finish_reason": "stop", + }, + }, + } + finalEncoded, _ := jsonMarshalRaw(finalChunk) + if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil { + return &ms, err + } + if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil { + return &ms, err + } + if flusher != nil { + flusher.Flush() + } + return &ms, nil +} + +func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) { + if c == nil { + return + } + if stream { + flusher, _ := c.Writer.(http.Flusher) + errorData := map[string]any{ + "error": map[string]string{ + "type": errType, + "message": message, + }, + } + jsonBytes, err := json.Marshal(errorData) + if err != nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) + _, _ = fmt.Fprint(c.Writer, errorEvent) + _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + return + } + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error { + if err == nil { + return nil + } + var upstreamErr *SoraUpstreamError + if errors.As(err, &upstreamErr) { + accountID := int64(0) + if account != nil { + accountID = account.ID + } + logger.LegacyPrintf( + "service.sora", + "[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s", + accountID, + model, + upstreamErr.StatusCode, + strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")), + strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")), + strings.TrimSpace(upstreamErr.Message), + truncateForLog(upstreamErr.Body, 1024), + ) + if s.rateLimitService != nil && account != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) + } + if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { + var responseHeaders http.Header + if upstreamErr.Headers != nil { + responseHeaders = upstreamErr.Headers.Clone() + } + return &UpstreamFailoverError{ + StatusCode: upstreamErr.StatusCode, + ResponseBody: upstreamErr.Body, + ResponseHeaders: responseHeaders, + } + } + msg := upstreamErr.Message + if override := soraProErrorMessage(model, msg); override != "" { + msg = override + } + s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream) + return err + } + if errors.Is(err, context.DeadlineExceeded) { + s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream) + return err + } + s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream) + return err +} + +func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { + interval := s.pollInterval() + maxAttempts := s.pollMaxAttempts() + lastPing := time.Now() + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetImageTask(ctx, account, taskID) + if err != nil { + return nil, err + } + switch strings.ToLower(status.Status) { + case "succeeded", "completed": + return status.URLs, nil + case "failed": + if status.ErrorMsg != "" { + return nil, errors.New(status.ErrorMsg) + } + return nil, errors.New("sora image generation failed") + } + if stream { + s.maybeSendPing(c, &lastPing) + } + if err := sleepWithContext(ctx, interval); err != nil { + return nil, err + } + } + return nil, errors.New("sora image generation timeout") +} + +func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) { + interval := s.pollInterval() + maxAttempts := s.pollMaxAttempts() + lastPing := time.Now() + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetVideoTask(ctx, account, taskID) + if err != nil { + return nil, err + } + switch strings.ToLower(status.Status) { + case "completed", "succeeded": + return status, nil + case "failed": + if status.ErrorMsg != "" { + return nil, errors.New(status.ErrorMsg) + } + return nil, errors.New("sora video generation failed") + } + if stream { + s.maybeSendPing(c, &lastPing) + } + if err := sleepWithContext(ctx, interval); err != nil { + return nil, err + } + } + return nil, errors.New("sora video generation timeout") +} + +func (s *SoraGatewayService) pollInterval() time.Duration { + if s == nil || s.cfg == nil { + return 2 * time.Second + } + interval := s.cfg.Sora.Client.PollIntervalSeconds + if interval <= 0 { + interval = 2 + } + return time.Duration(interval) * time.Second +} + +func (s *SoraGatewayService) pollMaxAttempts() int { + if s == nil || s.cfg == nil { + return 600 + } + maxAttempts := s.cfg.Sora.Client.MaxPollAttempts + if maxAttempts <= 0 { + maxAttempts = 600 + } + return maxAttempts +} + +func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) { + if c == nil { + return + } + interval := 10 * time.Second + if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 { + interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second + } + if time.Since(*lastPing) < interval { + return + } + if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil { + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } + *lastPing = time.Now() + } +} + +func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string { + if len(urls) == 0 { + return urls + } + output := make([]string, 0, len(urls)) + for _, raw := range urls { + raw = strings.TrimSpace(raw) + if raw == "" { + continue + } + if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + output = append(output, raw) + continue + } + pathVal := raw + if !strings.HasPrefix(pathVal, "/") { + pathVal = "/" + pathVal + } + output = append(output, s.buildSoraMediaURL(pathVal, "")) + } + return output +} + +// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符, +// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。 +func jsonMarshalRaw(v any) ([]byte, error) { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) + if err := enc.Encode(v); err != nil { + return nil, err + } + // Encode 会追加换行符,去掉它 + b := buf.Bytes() + if len(b) > 0 && b[len(b)-1] == '\n' { + b = b[:len(b)-1] + } + return b, nil +} + +func buildSoraContent(mediaType string, urls []string) string { + switch mediaType { + case "image": + parts := make([]string, 0, len(urls)) + for _, u := range urls { + parts = append(parts, fmt.Sprintf("![image](%s)", u)) + } + return strings.Join(parts, "\n") + case "video": + if len(urls) == 0 { + return "" + } + return fmt.Sprintf("```html\n\n```", urls[0]) + default: + return "" + } +} + +func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) { + if body == nil { + return "", "", "", "" + } + if v, ok := body["remix_target_id"].(string); ok { + remixTargetID = strings.TrimSpace(v) + } + if v, ok := body["image"].(string); ok { + imageInput = v + } + if v, ok := body["video"].(string); ok { + videoInput = v + } + if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" { + prompt = v + } + if messages, ok := body["messages"].([]any); ok { + builder := strings.Builder{} + for _, raw := range messages { + msg, ok := raw.(map[string]any) + if !ok { + continue + } + role, _ := msg["role"].(string) + if role != "" && role != "user" { + continue + } + content := msg["content"] + text, img, vid := parseSoraMessageContent(content) + if text != "" { + if builder.Len() > 0 { + _, _ = builder.WriteString("\n") + } + _, _ = builder.WriteString(text) + } + if imageInput == "" && img != "" { + imageInput = img + } + if videoInput == "" && vid != "" { + videoInput = vid + } + } + if prompt == "" { + prompt = builder.String() + } + } + if remixTargetID == "" { + remixTargetID = extractRemixTargetIDFromPrompt(prompt) + } + prompt = cleanRemixLinkFromPrompt(prompt) + return prompt, imageInput, videoInput, remixTargetID +} + +func parseSoraMessageContent(content any) (text, imageInput, videoInput string) { + switch val := content.(type) { + case string: + return val, "", "" + case []any: + builder := strings.Builder{} + for _, item := range val { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + t, _ := itemMap["type"].(string) + switch t { + case "text": + if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" { + if builder.Len() > 0 { + _, _ = builder.WriteString("\n") + } + _, _ = builder.WriteString(txt) + } + case "image_url": + if imageInput == "" { + if urlVal, ok := itemMap["image_url"].(map[string]any); ok { + imageInput = fmt.Sprintf("%v", urlVal["url"]) + } else if urlStr, ok := itemMap["image_url"].(string); ok { + imageInput = urlStr + } + } + case "video_url": + if videoInput == "" { + if urlVal, ok := itemMap["video_url"].(map[string]any); ok { + videoInput = fmt.Sprintf("%v", urlVal["url"]) + } else if urlStr, ok := itemMap["video_url"].(string); ok { + videoInput = urlStr + } + } + } + } + return builder.String(), imageInput, videoInput + default: + return "", "", "" + } +} + +func isSoraStoryboardPrompt(prompt string) bool { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return false + } + return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1 +} + +func formatSoraStoryboardPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return "" + } + matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1) + if len(matches) == 0 { + return prompt + } + firstBracketPos := strings.Index(prompt, "[") + instructions := "" + if firstBracketPos > 0 { + instructions = strings.TrimSpace(prompt[:firstBracketPos]) + } + shots := make([]string, 0, len(matches)) + for i, match := range matches { + if len(match) < 3 { + continue + } + duration := strings.TrimSpace(match[1]) + scene := strings.TrimSpace(match[2]) + if scene == "" { + continue + } + shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene)) + } + if len(shots) == 0 { + return prompt + } + timeline := strings.Join(shots, "\n\n") + if instructions == "" { + return timeline + } + return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions) +} + +func extractRemixTargetIDFromPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return "" + } + return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt)) +} + +func cleanRemixLinkFromPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return prompt + } + cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "") + cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "") + cleaned = strings.Join(strings.Fields(cleaned), " ") + return strings.TrimSpace(cleaned) +} + +func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) { + raw := strings.TrimSpace(input) + if raw == "" { + return nil, "", errors.New("empty image input") + } + if strings.HasPrefix(raw, "data:") { + parts := strings.SplitN(raw, ",", 2) + if len(parts) != 2 { + return nil, "", errors.New("invalid data url") + } + meta := parts[0] + payload := parts[1] + decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes) + if err != nil { + return nil, "", err + } + ext := "" + if strings.HasPrefix(meta, "data:") { + metaParts := strings.SplitN(meta[5:], ";", 2) + if len(metaParts) > 0 { + if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 { + ext = exts[0] + } + } + } + filename := "image" + ext + return decoded, filename, nil + } + if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + return downloadSoraImageInput(ctx, raw) + } + decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes) + if err != nil { + return nil, "", errors.New("invalid base64 image") + } + return decoded, "image.png", nil +} + +func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) { + raw := strings.TrimSpace(input) + if raw == "" { + return nil, errors.New("empty video input") + } + if strings.HasPrefix(raw, "data:") { + parts := strings.SplitN(raw, ",", 2) + if len(parts) != 2 { + return nil, errors.New("invalid video data url") + } + decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes) + if err != nil { + return nil, errors.New("invalid base64 video") + } + if len(decoded) == 0 { + return nil, errors.New("empty video data") + } + return decoded, nil + } + if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + return downloadSoraVideoInput(ctx, raw) + } + decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes) + if err != nil { + return nil, errors.New("invalid base64 video") + } + if len(decoded) == 0 { + return nil, errors.New("empty video data") + } + return decoded, nil +} + +func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) { + parsed, err := validateSoraRemoteURL(rawURL) + if err != nil { + return nil, "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) + if err != nil { + return nil, "", err + } + client := &http.Client{ + Timeout: soraImageInputTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= soraImageInputMaxRedirects { + return errors.New("too many redirects") + } + return validateSoraRemoteURLValue(req.URL) + }, + } + resp, err := client.Do(req) + if err != nil { + return nil, "", err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode) + } + data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes)) + if err != nil { + return nil, "", err + } + ext := fileExtFromURL(parsed.String()) + if ext == "" { + ext = fileExtFromContentType(resp.Header.Get("Content-Type")) + } + filename := "image" + ext + return data, filename, nil +} + +func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) { + parsed, err := validateSoraRemoteURL(rawURL) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) + if err != nil { + return nil, err + } + client := &http.Client{ + Timeout: soraVideoInputTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= soraVideoInputMaxRedirects { + return errors.New("too many redirects") + } + return validateSoraRemoteURLValue(req.URL) + }, + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("download video failed: %d", resp.StatusCode) + } + data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes)) + if err != nil { + return nil, err + } + if len(data) == 0 { + return nil, errors.New("empty video content") + } + return data, nil +} + +func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) { + if maxBytes <= 0 { + return nil, errors.New("invalid max bytes limit") + } + decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + limited := io.LimitReader(decoder, maxBytes+1) + data, err := io.ReadAll(limited) + if err != nil { + return nil, err + } + if int64(len(data)) > maxBytes { + return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes) + } + return data, nil +} + +func validateSoraRemoteURL(raw string) (*url.URL, error) { + if strings.TrimSpace(raw) == "" { + return nil, errors.New("empty remote url") + } + parsed, err := url.Parse(raw) + if err != nil { + return nil, fmt.Errorf("invalid remote url: %w", err) + } + if err := validateSoraRemoteURLValue(parsed); err != nil { + return nil, err + } + return parsed, nil +} + +func validateSoraRemoteURLValue(parsed *url.URL) error { + if parsed == nil { + return errors.New("invalid remote url") + } + scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) + if scheme != "http" && scheme != "https" { + return errors.New("only http/https remote url is allowed") + } + if parsed.User != nil { + return errors.New("remote url cannot contain userinfo") + } + host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + if host == "" { + return errors.New("remote url missing host") + } + if _, blocked := soraBlockedHostnames[host]; blocked { + return errors.New("remote url is not allowed") + } + if ip := net.ParseIP(host); ip != nil { + if isSoraBlockedIP(ip) { + return errors.New("remote url is not allowed") + } + return nil + } + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("resolve remote url failed: %w", err) + } + for _, ip := range ips { + if isSoraBlockedIP(ip) { + return errors.New("remote url is not allowed") + } + } + return nil +} + +func isSoraBlockedIP(ip net.IP) bool { + if ip == nil { + return true + } + for _, cidr := range soraBlockedCIDRs { + if cidr.Contains(ip) { + return true + } + } + return false +} + +func mustParseCIDRs(values []string) []*net.IPNet { + out := make([]*net.IPNet, 0, len(values)) + for _, val := range values { + _, cidr, err := net.ParseCIDR(val) + if err != nil { + continue + } + out = append(out, cidr) + } + return out +} diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go new file mode 100644 index 00000000..206636ff --- /dev/null +++ b/backend/internal/service/sora_gateway_service_test.go @@ -0,0 +1,558 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +var _ SoraClient = (*stubSoraClientForPoll)(nil) + +type stubSoraClientForPoll struct { + imageStatus *SoraImageTaskStatus + videoStatus *SoraVideoTaskStatus + imageCalls int + videoCalls int + enhanced string + enhanceErr error + storyboard bool + videoReq SoraVideoRequest + parseErr error + postCalls int + deleteCalls int +} + +func (s *stubSoraClientForPoll) Enabled() bool { return true } +func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { + return "", nil +} +func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { + return "task-image", nil +} +func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { + s.videoReq = req + return "task-video", nil +} +func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { + s.storyboard = true + return "task-video", nil +} +func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { + return "cameo-1", nil +} +func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + return &SoraCameoStatus{ + Status: "finalized", + StatusMessage: "Completed", + DisplayNameHint: "Character", + UsernameHint: "user.character", + ProfileAssetURL: "https://example.com/avatar.webp", + }, nil +} +func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { + return []byte("avatar"), nil +} +func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { + return "asset-pointer", nil +} +func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { + return "character-1", nil +} +func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { + return nil +} +func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { + return nil +} +func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { + s.postCalls++ + return "s_post", nil +} +func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error { + s.deleteCalls++ + return nil +} +func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { + if s.parseErr != nil { + return "", s.parseErr + } + return "https://example.com/no-watermark.mp4", nil +} +func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { + if s.enhanced != "" { + return s.enhanced, s.enhanceErr + } + return "enhanced prompt", s.enhanceErr +} +func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { + s.imageCalls++ + return s.imageStatus, nil +} +func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { + s.videoCalls++ + return s.videoStatus, nil +} + +func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) { + client := &stubSoraClientForPoll{ + imageStatus: &SoraImageTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/a.png"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + service := NewSoraGatewayService(client, nil, nil, cfg) + + urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false) + require.NoError(t, err) + require.Equal(t, []string{"https://example.com/a.png"}, urls) + require.Equal(t, 1, client.imageCalls) +} + +func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { + client := &stubSoraClientForPoll{ + enhanced: "cinematic prompt", + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + } + body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "prompt", result.MediaType) + require.Equal(t, "prompt-enhance-short-10s", result.Model) +} + +func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/v.mp4"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, client.storyboard) +} + +func TestSoraGatewayService_ForwardVideoCount(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/v.mp4"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 3, client.videoReq.VideoCount) +} + +func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) { + client := &stubSoraClientForPoll{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "prompt", result.MediaType) + require.Equal(t, 0, client.videoCalls) +} + +func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/original.mp4"}, + GenerationID: "gen_1", + }, + parseErr: errors.New("parse failed"), + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "https://example.com/original.mp4", result.MediaURL) + require.Equal(t, 1, client.postCalls) + require.Equal(t, 0, client.deleteCalls) +} + +func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/original.mp4"}, + GenerationID: "gen_1", + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL) + require.Equal(t, 1, client.postCalls) + require.Equal(t, 1, client.deleteCalls) +} + +func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "failed", + ErrorMsg: "reject", + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + service := NewSoraGatewayService(client, nil, nil, cfg) + + status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false) + require.Error(t, err) + require.Nil(t, status) + require.Contains(t, err.Error(), "reject") + require.Equal(t, 1, client.videoCalls) +} + +func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) { + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + SoraMediaSigningKey: "test-key", + SoraMediaSignedURLTTLSeconds: 600, + }, + } + service := NewSoraGatewayService(nil, nil, nil, cfg) + + url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "") + require.Contains(t, url, "/sora/media-signed") + require.Contains(t, url, "expires=") + require.Contains(t, url, "sig=") +} + +func TestNormalizeSoraMediaURLs_Empty(t *testing.T) { + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + result := svc.normalizeSoraMediaURLs(nil) + require.Empty(t, result) + + result = svc.normalizeSoraMediaURLs([]string{}) + require.Empty(t, result) +} + +func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) { + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"} + result := svc.normalizeSoraMediaURLs(urls) + require.Equal(t, urls, result) +} + +func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) { + cfg := &config.Config{} + svc := NewSoraGatewayService(nil, nil, nil, cfg) + urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"} + result := svc.normalizeSoraMediaURLs(urls) + require.Len(t, result, 2) + require.Contains(t, result[0], "/sora/media") + require.Contains(t, result[1], "/sora/media") +} + +func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) { + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"} + result := svc.normalizeSoraMediaURLs(urls) + require.Len(t, result, 2) +} + +func TestBuildSoraContent_Image(t *testing.T) { + content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"}) + require.Contains(t, content, "![image](https://a.com/1.png)") + require.Contains(t, content, "![image](https://a.com/2.png)") +} + +func TestBuildSoraContent_Video(t *testing.T) { + content := buildSoraContent("video", []string{"https://a.com/v.mp4"}) + require.Contains(t, content, "