diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index d21d0684..01c00bb9 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -19,7 +19,7 @@ jobs: cache: true - name: Verify Go version run: | - go version | grep -q 'go1.25.7' + go version | grep -q 'go1.26.1' - name: Unit tests working-directory: backend run: make test-unit @@ -38,10 +38,10 @@ jobs: cache: true - name: Verify Go version run: | - go version | grep -q 'go1.25.7' + go version | grep -q 'go1.26.1' - name: golangci-lint uses: golangci/golangci-lint-action@v9 with: - version: v2.7 + version: v2.9 args: --timeout=30m working-directory: backend \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a1c6aa23..5c0524c8 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -115,7 +115,7 @@ jobs: - name: Verify Go version run: | - go version | grep -q 'go1.25.7' + go version | grep -q 'go1.26.1' # Docker setup for GoReleaser - name: Set up QEMU diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml index db922509..cc5a90cf 100644 --- a/.github/workflows/security-scan.yml +++ b/.github/workflows/security-scan.yml @@ -23,7 +23,7 @@ jobs: cache-dependency-path: backend/go.sum - name: Verify Go version run: | - go version | grep -q 'go1.25.7' + go version | grep -q 'go1.26.1' - name: Run govulncheck working-directory: backend run: | diff --git a/AGENTS.md b/AGENTS.md deleted file mode 100644 index bb5bb465..00000000 --- a/AGENTS.md +++ /dev/null @@ -1,105 +0,0 @@ -# Repository Guidelines - -## Project Structure & Module Organization -- `backend/`: Go service. `cmd/server` is the entrypoint, `internal/` contains handlers/services/repositories/server wiring, `ent/` holds Ent schemas and generated ORM code, `migrations/` stores DB migrations, and `internal/web/dist/` is the embedded frontend build output. 
-- `frontend/`: Vue 3 + TypeScript app. Main folders are `src/api`, `src/components`, `src/views`, `src/stores`, `src/composables`, `src/utils`, and test files in `src/**/__tests__`. -- `deploy/`: Docker and deployment assets (`docker-compose*.yml`, `.env.example`, `config.example.yaml`). -- `openspec/`: Spec-driven change docs (`changes//{proposal,design,tasks}.md`). -- `tools/`: Utility scripts (security/perf checks). - -## Build, Test, and Development Commands -```bash -make build # Build backend + frontend -make test # Backend tests + frontend lint/typecheck -cd backend && make build # Build backend binary -cd backend && make test-unit # Go unit tests -cd backend && make test-integration # Go integration tests -cd backend && make test # go test ./... + golangci-lint -cd frontend && pnpm install --frozen-lockfile -cd frontend && pnpm dev # Vite dev server -cd frontend && pnpm build # Type-check + production build -cd frontend && pnpm test:run # Vitest run -cd frontend && pnpm test:coverage # Vitest + coverage report -python3 tools/secret_scan.py # Secret scan -``` - -## Coding Style & Naming Conventions -- Go: format with `gofmt`; lint with `golangci-lint` (`backend/.golangci.yml`). -- Respect layering: `internal/service` and `internal/handler` must not import `internal/repository`, `gorm`, or `redis` directly (enforced by depguard). -- Frontend: Vue SFC + TypeScript, 2-space indentation, ESLint rules from `frontend/.eslintrc.cjs`. -- Naming: components use `PascalCase.vue`, composables use `useXxx.ts`, Go tests use `*_test.go`, frontend tests use `*.spec.ts`. - -## Go & Frontend Development Standards -- Control branch complexity: `if` nesting must not exceed 3 levels. Refactor with guard clauses, early returns, helper functions, or strategy maps when deeper logic appears. -- JSON hot-path rule: for read-only/partial-field extraction, prefer `gjson` over full `encoding/json` struct unmarshal to reduce allocations and improve latency. 
-- Exception rule: if full schema validation or typed writes are required, `encoding/json` is allowed, but PR must explain why `gjson` is not suitable. - -### Go Performance Rules -- Optimization workflow rule: benchmark/profile first, then optimize. Use `go test -bench`, `go tool pprof`, and runtime diagnostics before changing hot-path code. -- For hot functions, run escape analysis (`go build -gcflags=all='-m -m'`) and prioritize stack allocation where reasonable. -- Every external I/O path must use `context.Context` with explicit timeout/cancel. -- When creating derived contexts (`WithTimeout` / `WithDeadline`), always `defer cancel()` to release resources. -- Preallocate slices/maps when size can be estimated (`make([]T, 0, n)`, `make(map[K]V, n)`). -- Avoid unnecessary allocations in loops; reuse buffers and prefer `strings.Builder`/`bytes.Buffer`. -- Prohibit N+1 query patterns; batch DB/Redis operations and verify indexes for new query paths. -- For hot-path changes, include benchmark or latency comparison evidence (e.g., `go test -bench` before/after). -- Keep goroutine growth bounded (worker pool/semaphore), and avoid unbounded fan-out. -- Lock minimization rule: if a lock can be avoided, do not use a lock. Prefer ownership transfer (channel), sharding, immutable snapshots, copy-on-write, or atomic operations to reduce contention. -- When locks are unavoidable, keep critical sections minimal, avoid nested locks, and document why lock-free alternatives are not feasible. -- Follow `sync` guidance: prefer channels for higher-level synchronization; use low-level mutex primitives only where necessary. -- Avoid reflection and `interface{}`-heavy conversions in hot paths; use typed structs/functions. -- Use `sync.Pool` only when benchmark proves allocation reduction; remove if no measurable gain. -- Avoid repeated `time.Now()`/`fmt.Sprintf` in tight loops; hoist or cache when possible. 
-- For stable high-traffic binaries, maintain representative `default.pgo` profiles and keep `go build -pgo=auto` enabled. - -### Data Access & Cache Rules -- Every new/changed SQL query must be checked with `EXPLAIN` (or `EXPLAIN ANALYZE` in staging) and include index rationale in PR. -- Default to keyset pagination for large tables; avoid deep `OFFSET` scans on hot endpoints. -- Query only required columns; prohibit broad `SELECT *` in latency-sensitive paths. -- Keep transactions short; never perform external RPC/network calls inside DB transactions. -- Connection pool must be explicitly tuned and observed via `DB.Stats` (`SetMaxOpenConns`, `SetMaxIdleConns`, `SetConnMaxIdleTime`, `SetConnMaxLifetime`). -- Avoid overly small `MaxOpenConns` that can turn DB access into lock/semaphore bottlenecks. -- Cache keys must be versioned (e.g., `user_usage:v2:{id}`) and TTL should include jitter to avoid thundering herd. -- Use request coalescing (`singleflight` or equivalent) for high-concurrency cache miss paths. - -### Frontend Performance Rules -- Route-level and heavy-module code splitting is required; lazy-load non-critical views/components. -- API requests must support cancellation and deduplication; use debounce/throttle for search-like inputs. -- Minimize unnecessary reactivity: avoid deep watch chains when computed/cache can solve it. -- Prefer stable props and selective rendering controls (`v-once`, `v-memo`) for expensive subtrees when data is static or keyed. -- Large data rendering must use pagination or virtualization (especially tables/lists >200 rows). -- Move expensive CPU work off the main thread (Web Worker) or chunk tasks to avoid UI blocking. -- Keep bundle growth controlled; avoid adding heavy dependencies without clear ROI and alternatives review. -- Avoid expensive inline computations in templates; move to cached `computed` selectors. -- Keep state normalized; avoid duplicated derived state across multiple stores/components. 
-- Load charts/editors/export libraries on demand only (`dynamic import`) instead of app-entry import. -- Core Web Vitals targets (p75): `LCP <= 2.5s`, `INP <= 200ms`, `CLS <= 0.1`. -- Main-thread task budget: keep individual tasks below ~50ms; split long tasks and yield between chunks. -- Enforce frontend budgets in CI (Lighthouse CI with `budget.json`) for critical routes. - -### Performance Budget & PR Evidence -- Performance budget is mandatory for hot-path PRs: backend p95/p99 latency and CPU/memory must not regress by more than 5% versus baseline. -- Frontend budget: new route-level JS should not increase by more than 30KB gzip without explicit approval. -- For any gateway/protocol hot path, attach a reproducible benchmark command and results (input size, concurrency, before/after table). -- Profiling evidence is required for major optimizations (`pprof`, flamegraph, browser performance trace, or bundle analyzer output). - -### Quality Gate -- Any changed code must include new or updated unit tests. -- Coverage must stay above 85% (global frontend threshold and no regressions for touched backend modules). -- If any rule is intentionally violated, document reason, risk, and mitigation in the PR description. - -## Testing Guidelines -- Backend suites: `go test -tags=unit ./...`, `go test -tags=integration ./...`, and e2e where relevant. -- Frontend uses Vitest (`jsdom`); keep tests near modules (`__tests__`) or as `*.spec.ts`. -- Enforce unit-test and coverage rules defined in `Quality Gate`. -- Before opening a PR, run `make test` plus targeted tests for touched areas. - -## Commit & Pull Request Guidelines -- Follow Conventional Commits: `feat(scope): ...`, `fix(scope): ...`, `chore(scope): ...`, `docs(scope): ...`. -- PRs should include a clear summary, linked issue/spec, commands run for verification, and screenshots/GIFs for UI changes. -- For behavior/API changes, add or update `openspec/changes/...` artifacts. 
-- If dependencies change, commit `frontend/pnpm-lock.yaml` in the same PR. - -## Security & Configuration Tips -- Use `deploy/.env.example` and `deploy/config.example.yaml` as templates; do not commit real credentials. -- Set stable `JWT_SECRET`, `TOTP_ENCRYPTION_KEY`, and strong database passwords outside local dev. diff --git a/Dockerfile b/Dockerfile index 1493e8a7..8fd48cc2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,8 +7,9 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.25.7-alpine +ARG GOLANG_IMAGE=golang:1.26.1-alpine ARG ALPINE_IMAGE=alpine:3.21 +ARG POSTGRES_IMAGE=postgres:18-alpine ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn @@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \ ./cmd/server # ----------------------------------------------------------------------------- -# Stage 3: Final Runtime Image +# Stage 3: PostgreSQL Client (version-matched with docker-compose) +# ----------------------------------------------------------------------------- +FROM ${POSTGRES_IMAGE} AS pg-client + +# ----------------------------------------------------------------------------- +# Stage 4: Final Runtime Image # ----------------------------------------------------------------------------- FROM ${ALPINE_IMAGE} @@ -86,8 +92,20 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" RUN apk add --no-cache \ ca-certificates \ tzdata \ + libpq \ + zstd-libs \ + lz4-libs \ + krb5-libs \ + libldap \ + libedit \ && rm -rf /var/cache/apk/* +# Copy pg_dump and psql from the same postgres image used in docker-compose +# This ensures version consistency between backup tools and the database server +COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump +COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql +COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/ + # Create non-root user RUN addgroup -g 1000 
sub2api && \ adduser -u 1000 -G sub2api -s /bin/sh -D sub2api diff --git a/Dockerfile.goreleaser b/Dockerfile.goreleaser index 2242c162..419994b9 100644 --- a/Dockerfile.goreleaser +++ b/Dockerfile.goreleaser @@ -5,7 +5,12 @@ # It only packages the pre-built binary, no compilation needed. # ============================================================================= -FROM alpine:3.19 +ARG ALPINE_IMAGE=alpine:3.21 +ARG POSTGRES_IMAGE=postgres:18-alpine + +FROM ${POSTGRES_IMAGE} AS pg-client + +FROM ${ALPINE_IMAGE} LABEL maintainer="Wei-Shaw " LABEL description="Sub2API - AI API Gateway Platform" @@ -16,8 +21,20 @@ RUN apk add --no-cache \ ca-certificates \ tzdata \ curl \ + libpq \ + zstd-libs \ + lz4-libs \ + krb5-libs \ + libldap \ + libedit \ && rm -rf /var/cache/apk/* +# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and +# restore work in the runtime container without requiring Docker socket access. +COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump +COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql +COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/ + # Create non-root user RUN addgroup -g 1000 sub2api && \ adduser -u 1000 -G sub2api -s /bin/sh -D sub2api diff --git a/README.md b/README.md index 1e2f2290..4a7bde8e 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,16 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot - **Concurrency Control** - Per-user and per-account concurrency limits - **Rate Limiting** - Configurable request and token rate limits - **Admin Dashboard** - Web interface for monitoring and management +- **External System Integration** - Embed external systems (e.g. 
payment, ticketing) via iframe to extend the admin dashboard + +## Ecosystem + +Community projects that extend or integrate with Sub2API: + +| Project | Description | Features | +|---------|-------------|----------| +| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe | +| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native | ## Tech Stack @@ -150,14 +160,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash # Start services -docker-compose -f docker-compose.local.yml up -d +docker-compose up -d # View logs -docker-compose -f docker-compose.local.yml logs -f sub2api +docker-compose logs -f sub2api ``` **What the script does:** -- Downloads `docker-compose.local.yml` and `.env.example` +- Downloads `docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example` - Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD) - Creates `.env` file with auto-generated secrets - Creates data directories (uses local directories for easy backup/migration) @@ -522,6 +532,28 @@ sub2api/ └── install.sh # One-click installation script ``` +## Disclaimer + +> **Please read carefully before using this project:** +> +> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user. +> +> :book: **Disclaimer**: This project is for technical learning and research purposes only. 
The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project. + +--- + +## Star History + + + + + + Star History Chart + + + +--- + ## License MIT License diff --git a/README_CN.md b/README_CN.md index 9da089b7..eee89b07 100644 --- a/README_CN.md +++ b/README_CN.md @@ -39,6 +39,16 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( - **并发控制** - 用户级和账号级并发限制 - **速率限制** - 可配置的请求和 Token 速率限制 - **管理后台** - Web 界面进行监控和管理 +- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能 + +## 生态项目 + +围绕 Sub2API 的社区扩展与集成项目: + +| 项目 | 说明 | 功能 | +|------|------|------| +| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 | +| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 | ## 技术栈 @@ -137,8 +147,6 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install 使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。 -如果你的服务器是 **Ubuntu 24.04**,建议直接参考:`deploy/ubuntu24-docker-compose-aicodex.md`,其中包含「安装最新版 Docker + docker-compose-aicodex.yml 部署」的完整步骤。 - #### 前置条件 - Docker 20.10+ @@ -156,14 +164,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash # 启动服务 -docker-compose -f docker-compose.local.yml up -d +docker-compose up -d # 查看日志 -docker-compose -f docker-compose.local.yml logs -f sub2api +docker-compose logs -f sub2api ``` **脚本功能:** -- 下载 `docker-compose.local.yml` 和 `.env.example` +- 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml`)和 `.env.example` - 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD) - 创建 `.env` 文件并填充自动生成的密钥 - 创建数据目录(使用本地目录,便于备份和迁移) @@ -590,6 +598,28 @@ sub2api/ └── install.sh # 一键安装脚本 ``` +## 免责声明 + +> **使用本项目前请仔细阅读:** +> +> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 
的用户协议,使用本项目的一切风险由用户自行承担。 +> +> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。 + +--- + +## Star History + + + + + + Star History Chart + + + +--- + ## 许可证 MIT License diff --git a/backend/.golangci.yml b/backend/.golangci.yml index 68b76751..92ba3916 100644 --- a/backend/.golangci.yml +++ b/backend/.golangci.yml @@ -93,20 +93,13 @@ linters: check-escaping-errors: true staticcheck: # https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist - # Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"] dot-import-whitelist: - fmt # https://staticcheck.dev/docs/configuration/options/#initialisms - # Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"] initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ] # https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist - # Default: ["200", "400", "404", "500"] http-status-code-whitelist: [ "200", "400", "404", "500" ] - # SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks - # Example (to disable some checks): [ "all", "-SA1000", "-SA1001"] - # Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks. 
- # Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"] - # Temporarily disable style checks to allow CI to pass + # "all" enables every SA/ST/S/QF check; only list the ones to disable. checks: - all - -ST1000 # Package comment format @@ -114,489 +107,19 @@ linters: - -ST1020 # Comment on exported method format - -ST1021 # Comment on exported type format - -ST1022 # Comment on exported variable format - # Invalid regular expression. - # https://staticcheck.dev/docs/checks/#SA1000 - - SA1000 - # Invalid template. - # https://staticcheck.dev/docs/checks/#SA1001 - - SA1001 - # Invalid format in 'time.Parse'. - # https://staticcheck.dev/docs/checks/#SA1002 - - SA1002 - # Unsupported argument to functions in 'encoding/binary'. - # https://staticcheck.dev/docs/checks/#SA1003 - - SA1003 - # Suspiciously small untyped constant in 'time.Sleep'. - # https://staticcheck.dev/docs/checks/#SA1004 - - SA1004 - # Invalid first argument to 'exec.Command'. - # https://staticcheck.dev/docs/checks/#SA1005 - - SA1005 - # 'Printf' with dynamic first argument and no further arguments. - # https://staticcheck.dev/docs/checks/#SA1006 - - SA1006 - # Invalid URL in 'net/url.Parse'. - # https://staticcheck.dev/docs/checks/#SA1007 - - SA1007 - # Non-canonical key in 'http.Header' map. - # https://staticcheck.dev/docs/checks/#SA1008 - - SA1008 - # '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results. - # https://staticcheck.dev/docs/checks/#SA1010 - - SA1010 - # Various methods in the "strings" package expect valid UTF-8, but invalid input is provided. - # https://staticcheck.dev/docs/checks/#SA1011 - - SA1011 - # A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead. - # https://staticcheck.dev/docs/checks/#SA1012 - - SA1012 - # 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second. 
- # https://staticcheck.dev/docs/checks/#SA1013 - - SA1013 - # Non-pointer value passed to 'Unmarshal' or 'Decode'. - # https://staticcheck.dev/docs/checks/#SA1014 - - SA1014 - # Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions. - # https://staticcheck.dev/docs/checks/#SA1015 - - SA1015 - # Trapping a signal that cannot be trapped. - # https://staticcheck.dev/docs/checks/#SA1016 - - SA1016 - # Channels used with 'os/signal.Notify' should be buffered. - # https://staticcheck.dev/docs/checks/#SA1017 - - SA1017 - # 'strings.Replace' called with 'n == 0', which does nothing. - # https://staticcheck.dev/docs/checks/#SA1018 - - SA1018 - # Using a deprecated function, variable, constant or field. - # https://staticcheck.dev/docs/checks/#SA1019 - - SA1019 - # Using an invalid host:port pair with a 'net.Listen'-related function. - # https://staticcheck.dev/docs/checks/#SA1020 - - SA1020 - # Using 'bytes.Equal' to compare two 'net.IP'. - # https://staticcheck.dev/docs/checks/#SA1021 - - SA1021 - # Modifying the buffer in an 'io.Writer' implementation. - # https://staticcheck.dev/docs/checks/#SA1023 - - SA1023 - # A string cutset contains duplicate characters. - # https://staticcheck.dev/docs/checks/#SA1024 - - SA1024 - # It is not possible to use '(*time.Timer).Reset''s return value correctly. - # https://staticcheck.dev/docs/checks/#SA1025 - - SA1025 - # Cannot marshal channels or functions. - # https://staticcheck.dev/docs/checks/#SA1026 - - SA1026 - # Atomic access to 64-bit variable must be 64-bit aligned. - # https://staticcheck.dev/docs/checks/#SA1027 - - SA1027 - # 'sort.Slice' can only be used on slices. - # https://staticcheck.dev/docs/checks/#SA1028 - - SA1028 - # Inappropriate key in call to 'context.WithValue'. - # https://staticcheck.dev/docs/checks/#SA1029 - - SA1029 - # Invalid argument in call to a 'strconv' function. 
- # https://staticcheck.dev/docs/checks/#SA1030 - - SA1030 - # Overlapping byte slices passed to an encoder. - # https://staticcheck.dev/docs/checks/#SA1031 - - SA1031 - # Wrong order of arguments to 'errors.Is'. - # https://staticcheck.dev/docs/checks/#SA1032 - - SA1032 - # 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition. - # https://staticcheck.dev/docs/checks/#SA2000 - - SA2000 - # Empty critical section, did you mean to defer the unlock?. - # https://staticcheck.dev/docs/checks/#SA2001 - - SA2001 - # Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed. - # https://staticcheck.dev/docs/checks/#SA2002 - - SA2002 - # Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead. - # https://staticcheck.dev/docs/checks/#SA2003 - - SA2003 - # 'TestMain' doesn't call 'os.Exit', hiding test failures. - # https://staticcheck.dev/docs/checks/#SA3000 - - SA3000 - # Assigning to 'b.N' in benchmarks distorts the results. - # https://staticcheck.dev/docs/checks/#SA3001 - - SA3001 - # Binary operator has identical expressions on both sides. - # https://staticcheck.dev/docs/checks/#SA4000 - - SA4000 - # '&*x' gets simplified to 'x', it does not copy 'x'. - # https://staticcheck.dev/docs/checks/#SA4001 - - SA4001 - # Comparing unsigned values against negative values is pointless. - # https://staticcheck.dev/docs/checks/#SA4003 - - SA4003 - # The loop exits unconditionally after one iteration. - # https://staticcheck.dev/docs/checks/#SA4004 - - SA4004 - # Field assignment that will never be observed. Did you mean to use a pointer receiver?. - # https://staticcheck.dev/docs/checks/#SA4005 - - SA4005 - # A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?. - # https://staticcheck.dev/docs/checks/#SA4006 - - SA4006 - # The variable in the loop condition never changes, are you incrementing the wrong variable?. 
- # https://staticcheck.dev/docs/checks/#SA4008 - - SA4008 - # A function argument is overwritten before its first use. - # https://staticcheck.dev/docs/checks/#SA4009 - - SA4009 - # The result of 'append' will never be observed anywhere. - # https://staticcheck.dev/docs/checks/#SA4010 - - SA4010 - # Break statement with no effect. Did you mean to break out of an outer loop?. - # https://staticcheck.dev/docs/checks/#SA4011 - - SA4011 - # Comparing a value against NaN even though no value is equal to NaN. - # https://staticcheck.dev/docs/checks/#SA4012 - - SA4012 - # Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo. - # https://staticcheck.dev/docs/checks/#SA4013 - - SA4013 - # An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either. - # https://staticcheck.dev/docs/checks/#SA4014 - - SA4014 - # Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful. - # https://staticcheck.dev/docs/checks/#SA4015 - - SA4015 - # Certain bitwise operations, such as 'x ^ 0', do not do anything useful. - # https://staticcheck.dev/docs/checks/#SA4016 - - SA4016 - # Discarding the return values of a function without side effects, making the call pointless. - # https://staticcheck.dev/docs/checks/#SA4017 - - SA4017 - # Self-assignment of variables. - # https://staticcheck.dev/docs/checks/#SA4018 - - SA4018 - # Multiple, identical build constraints in the same file. - # https://staticcheck.dev/docs/checks/#SA4019 - - SA4019 - # Unreachable case clause in a type switch. - # https://staticcheck.dev/docs/checks/#SA4020 - - SA4020 - # "x = append(y)" is equivalent to "x = y". - # https://staticcheck.dev/docs/checks/#SA4021 - - SA4021 - # Comparing the address of a variable against nil. - # https://staticcheck.dev/docs/checks/#SA4022 - - SA4022 - # Impossible comparison of interface value with untyped nil. 
- # https://staticcheck.dev/docs/checks/#SA4023 - - SA4023 - # Checking for impossible return value from a builtin function. - # https://staticcheck.dev/docs/checks/#SA4024 - - SA4024 - # Integer division of literals that results in zero. - # https://staticcheck.dev/docs/checks/#SA4025 - - SA4025 - # Go constants cannot express negative zero. - # https://staticcheck.dev/docs/checks/#SA4026 - - SA4026 - # '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL. - # https://staticcheck.dev/docs/checks/#SA4027 - - SA4027 - # 'x % 1' is always zero. - # https://staticcheck.dev/docs/checks/#SA4028 - - SA4028 - # Ineffective attempt at sorting slice. - # https://staticcheck.dev/docs/checks/#SA4029 - - SA4029 - # Ineffective attempt at generating random number. - # https://staticcheck.dev/docs/checks/#SA4030 - - SA4030 - # Checking never-nil value against nil. - # https://staticcheck.dev/docs/checks/#SA4031 - - SA4031 - # Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value. - # https://staticcheck.dev/docs/checks/#SA4032 - - SA4032 - # Assignment to nil map. - # https://staticcheck.dev/docs/checks/#SA5000 - - SA5000 - # Deferring 'Close' before checking for a possible error. - # https://staticcheck.dev/docs/checks/#SA5001 - - SA5001 - # The empty for loop ("for {}") spins and can block the scheduler. - # https://staticcheck.dev/docs/checks/#SA5002 - - SA5002 - # Defers in infinite loops will never execute. - # https://staticcheck.dev/docs/checks/#SA5003 - - SA5003 - # "for { select { ..." with an empty default branch spins. - # https://staticcheck.dev/docs/checks/#SA5004 - - SA5004 - # The finalizer references the finalized object, preventing garbage collection. - # https://staticcheck.dev/docs/checks/#SA5005 - - SA5005 - # Infinite recursive call. - # https://staticcheck.dev/docs/checks/#SA5007 - - SA5007 - # Invalid struct tag. - # https://staticcheck.dev/docs/checks/#SA5008 - - SA5008 - # Invalid Printf call. 
- # https://staticcheck.dev/docs/checks/#SA5009 - - SA5009 - # Impossible type assertion. - # https://staticcheck.dev/docs/checks/#SA5010 - - SA5010 - # Possible nil pointer dereference. - # https://staticcheck.dev/docs/checks/#SA5011 - - SA5011 - # Passing odd-sized slice to function expecting even size. - # https://staticcheck.dev/docs/checks/#SA5012 - - SA5012 - # Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'. - # https://staticcheck.dev/docs/checks/#SA6000 - - SA6000 - # Missing an optimization opportunity when indexing maps by byte slices. - # https://staticcheck.dev/docs/checks/#SA6001 - - SA6001 - # Storing non-pointer values in 'sync.Pool' allocates memory. - # https://staticcheck.dev/docs/checks/#SA6002 - - SA6002 - # Converting a string to a slice of runes before ranging over it. - # https://staticcheck.dev/docs/checks/#SA6003 - - SA6003 - # Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'. - # https://staticcheck.dev/docs/checks/#SA6005 - - SA6005 - # Using io.WriteString to write '[]byte'. - # https://staticcheck.dev/docs/checks/#SA6006 - - SA6006 - # Defers in range loops may not run when you expect them to. - # https://staticcheck.dev/docs/checks/#SA9001 - - SA9001 - # Using a non-octal 'os.FileMode' that looks like it was meant to be in octal. - # https://staticcheck.dev/docs/checks/#SA9002 - - SA9002 - # Empty body in an if or else branch. - # https://staticcheck.dev/docs/checks/#SA9003 - - SA9003 - # Only the first constant has an explicit type. - # https://staticcheck.dev/docs/checks/#SA9004 - - SA9004 - # Trying to marshal a struct with no public fields nor custom marshaling. - # https://staticcheck.dev/docs/checks/#SA9005 - - SA9005 - # Dubious bit shifting of a fixed size integer value. - # https://staticcheck.dev/docs/checks/#SA9006 - - SA9006 - # Deleting a directory that shouldn't be deleted. 
- # https://staticcheck.dev/docs/checks/#SA9007 - - SA9007 - # 'else' branch of a type assertion is probably not reading the right value. - # https://staticcheck.dev/docs/checks/#SA9008 - - SA9008 - # Ineffectual Go compiler directive. - # https://staticcheck.dev/docs/checks/#SA9009 - - SA9009 - # NOTE: ST1000, ST1001, ST1003, ST1020, ST1021, ST1022 are disabled above - # Incorrectly formatted error string. - # https://staticcheck.dev/docs/checks/#ST1005 - - ST1005 - # Poorly chosen receiver name. - # https://staticcheck.dev/docs/checks/#ST1006 - - ST1006 - # A function's error value should be its last return value. - # https://staticcheck.dev/docs/checks/#ST1008 - - ST1008 - # Poorly chosen name for variable of type 'time.Duration'. - # https://staticcheck.dev/docs/checks/#ST1011 - - ST1011 - # Poorly chosen name for error variable. - # https://staticcheck.dev/docs/checks/#ST1012 - - ST1012 - # Should use constants for HTTP error codes, not magic numbers. - # https://staticcheck.dev/docs/checks/#ST1013 - - ST1013 - # A switch's default case should be the first or last case. - # https://staticcheck.dev/docs/checks/#ST1015 - - ST1015 - # Use consistent method receiver names. - # https://staticcheck.dev/docs/checks/#ST1016 - - ST1016 - # Don't use Yoda conditions. - # https://staticcheck.dev/docs/checks/#ST1017 - - ST1017 - # Avoid zero-width and control characters in string literals. - # https://staticcheck.dev/docs/checks/#ST1018 - - ST1018 - # Importing the same package multiple times. - # https://staticcheck.dev/docs/checks/#ST1019 - - ST1019 - # NOTE: ST1020, ST1021, ST1022 removed (disabled above) - # Redundant type in variable declaration. - # https://staticcheck.dev/docs/checks/#ST1023 - - ST1023 - # Use plain channel send or receive instead of single-case select. - # https://staticcheck.dev/docs/checks/#S1000 - - S1000 - # Replace for loop with call to copy. - # https://staticcheck.dev/docs/checks/#S1001 - - S1001 - # Omit comparison with boolean constant. 
- # https://staticcheck.dev/docs/checks/#S1002 - - S1002 - # Replace call to 'strings.Index' with 'strings.Contains'. - # https://staticcheck.dev/docs/checks/#S1003 - - S1003 - # Replace call to 'bytes.Compare' with 'bytes.Equal'. - # https://staticcheck.dev/docs/checks/#S1004 - - S1004 - # Drop unnecessary use of the blank identifier. - # https://staticcheck.dev/docs/checks/#S1005 - - S1005 - # Use "for { ... }" for infinite loops. - # https://staticcheck.dev/docs/checks/#S1006 - - S1006 - # Simplify regular expression by using raw string literal. - # https://staticcheck.dev/docs/checks/#S1007 - - S1007 - # Simplify returning boolean expression. - # https://staticcheck.dev/docs/checks/#S1008 - - S1008 - # Omit redundant nil check on slices, maps, and channels. - # https://staticcheck.dev/docs/checks/#S1009 - - S1009 - # Omit default slice index. - # https://staticcheck.dev/docs/checks/#S1010 - - S1010 - # Use a single 'append' to concatenate two slices. - # https://staticcheck.dev/docs/checks/#S1011 - - S1011 - # Replace 'time.Now().Sub(x)' with 'time.Since(x)'. - # https://staticcheck.dev/docs/checks/#S1012 - - S1012 - # Use a type conversion instead of manually copying struct fields. - # https://staticcheck.dev/docs/checks/#S1016 - - S1016 - # Replace manual trimming with 'strings.TrimPrefix'. - # https://staticcheck.dev/docs/checks/#S1017 - - S1017 - # Use "copy" for sliding elements. - # https://staticcheck.dev/docs/checks/#S1018 - - S1018 - # Simplify "make" call by omitting redundant arguments. - # https://staticcheck.dev/docs/checks/#S1019 - - S1019 - # Omit redundant nil check in type assertion. - # https://staticcheck.dev/docs/checks/#S1020 - - S1020 - # Merge variable declaration and assignment. - # https://staticcheck.dev/docs/checks/#S1021 - - S1021 - # Omit redundant control flow. - # https://staticcheck.dev/docs/checks/#S1023 - - S1023 - # Replace 'x.Sub(time.Now())' with 'time.Until(x)'. 
- # https://staticcheck.dev/docs/checks/#S1024 - - S1024 - # Don't use 'fmt.Sprintf("%s", x)' unnecessarily. - # https://staticcheck.dev/docs/checks/#S1025 - - S1025 - # Simplify error construction with 'fmt.Errorf'. - # https://staticcheck.dev/docs/checks/#S1028 - - S1028 - # Range over the string directly. - # https://staticcheck.dev/docs/checks/#S1029 - - S1029 - # Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'. - # https://staticcheck.dev/docs/checks/#S1030 - - S1030 - # Omit redundant nil check around loop. - # https://staticcheck.dev/docs/checks/#S1031 - - S1031 - # Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'. - # https://staticcheck.dev/docs/checks/#S1032 - - S1032 - # Unnecessary guard around call to "delete". - # https://staticcheck.dev/docs/checks/#S1033 - - S1033 - # Use result of type assertion to simplify cases. - # https://staticcheck.dev/docs/checks/#S1034 - - S1034 - # Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'. - # https://staticcheck.dev/docs/checks/#S1035 - - S1035 - # Unnecessary guard around map access. - # https://staticcheck.dev/docs/checks/#S1036 - - S1036 - # Elaborate way of sleeping. - # https://staticcheck.dev/docs/checks/#S1037 - - S1037 - # Unnecessarily complex way of printing formatted string. - # https://staticcheck.dev/docs/checks/#S1038 - - S1038 - # Unnecessary use of 'fmt.Sprint'. - # https://staticcheck.dev/docs/checks/#S1039 - - S1039 - # Type assertion to current type. - # https://staticcheck.dev/docs/checks/#S1040 - - S1040 - # Apply De Morgan's law. - # https://staticcheck.dev/docs/checks/#QF1001 - - QF1001 - # Convert untagged switch to tagged switch. - # https://staticcheck.dev/docs/checks/#QF1002 - - QF1002 - # Convert if/else-if chain to tagged switch. - # https://staticcheck.dev/docs/checks/#QF1003 - - QF1003 - # Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'. 
- # https://staticcheck.dev/docs/checks/#QF1004 - - QF1004 - # Expand call to 'math.Pow'. - # https://staticcheck.dev/docs/checks/#QF1005 - - QF1005 - # Lift 'if'+'break' into loop condition. - # https://staticcheck.dev/docs/checks/#QF1006 - - QF1006 - # Merge conditional assignment into variable declaration. - # https://staticcheck.dev/docs/checks/#QF1007 - - QF1007 - # Omit embedded fields from selector expression. - # https://staticcheck.dev/docs/checks/#QF1008 - - QF1008 - # Use 'time.Time.Equal' instead of '==' operator. - # https://staticcheck.dev/docs/checks/#QF1009 - - QF1009 - # Convert slice of bytes to string when printing it. - # https://staticcheck.dev/docs/checks/#QF1010 - - QF1010 - # Omit redundant type from variable declaration. - # https://staticcheck.dev/docs/checks/#QF1011 - - QF1011 - # Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'. - # https://staticcheck.dev/docs/checks/#QF1012 - - QF1012 unused: - # Mark all struct fields that have been written to as used. # Default: true - field-writes-are-uses: false - # Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write. + field-writes-are-uses: true # Default: false post-statements-are-reads: true - # Mark all exported fields as used. - # default: true - exported-fields-are-used: false - # Mark all function parameters as used. - # default: true - parameters-are-used: true - # Mark all local variables as used. - # default: true - local-variables-are-used: false - # Mark all identifiers inside generated files as used. 
# Default: true - generated-is-used: false + exported-fields-are-used: true + # Default: true + parameters-are-used: true + # Default: true + local-variables-are-used: false + # Default: true — must be true, ent generates 130K+ lines of code + generated-is-used: true formatters: enable: diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index bc001693..7eabde62 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 15fdb0ba..d07b3832 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -100,7 +100,7 @@ func runSetupServer() { r := gin.New() r.Use(middleware.Recovery()) r.Use(middleware.CORS(config.CORSConfig{})) - r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy})) + r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}, nil)) // Register setup routes setup.RegisterRoutes(r) diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index cbf89ba3..7fc648ac 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { // Server layer ProviderSet server.ProviderSet, + // Privacy client factory for OpenAI training opt-out + providePrivacyClientFactory, + // BuildInfo provider provideServiceBuildInfo, @@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { return nil, nil } +func providePrivacyClientFactory() 
service.PrivacyClientFactory { + return repository.CreatePrivacyReqClient +} + func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { return service.BuildInfo{ Version: buildInfo.Version, @@ -86,6 +93,8 @@ func provideCleanup( geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, openAIGateway *service.OpenAIGatewayService, + scheduledTestRunner *service.ScheduledTestRunnerService, + backupSvc *service.BackupService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -216,6 +225,18 @@ func provideCleanup( } return nil }}, + {"ScheduledTestRunnerService", func() error { + if scheduledTestRunner != nil { + scheduledTestRunner.Stop() + } + return nil + }}, + {"BackupService", func() error { + if backupSvc != nil { + backupSvc.Stop() + } + return nil + }}, } infraSteps := []cleanupStep{ diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 90709f5b..f632bff3 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -58,15 +58,16 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { promoCodeRepository := repository.NewPromoCodeRepository(client) billingCache := repository.NewBillingCache(redisClient) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) - billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) - apiKeyRepository := repository.NewAPIKeyRepository(client) + apiKeyRepository := repository.NewAPIKeyRepository(client, db) + billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig) userGroupRateRepository := repository.NewUserGroupRateRepository(db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, 
groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) + apiKeyService.SetRateLimitCacheInvalidator(billingCache) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) - authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) + authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) @@ -80,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) + usageBillingRepository := repository.NewUsageBillingRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemHandler := handler.NewRedeemHandler(redeemService) @@ -103,7 +105,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyRepository := 
repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService) + privacyClientFactory := providePrivacyClientFactory() + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) @@ -121,6 +124,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tempUnschedCache := repository.NewTempUnschedCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient) + oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache) compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator) httpUpstream := repository.NewHTTPUpstream(configConfig) @@ -129,11 +133,11 @@ func 
initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { usageCache := service.NewUsageCache() identityCache := repository.NewIdentityCache(redisClient) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache) - geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) + geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI) gatewayCache := repository.NewGatewayCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) - antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) + antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) @@ -143,6 +147,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) dataManagementService := service.NewDataManagementService() dataManagementHandler := admin.NewDataManagementHandler(dataManagementService) + 
backupObjectStoreFactory := repository.NewS3BackupStoreFactory() + dbDumper := repository.NewPgDumper(configConfig) + backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper) + backupHandler := admin.NewBackupHandler(backupService, userService) oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) @@ -159,11 +167,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingService := service.NewBillingService(configConfig, pricingService) identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) - claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) + claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore) - openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) + 
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService) + openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) @@ -194,9 +202,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, 
adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler) + scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db) + scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) + scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) + scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig, settingService) + userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) + userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, 
settingService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) @@ -219,10 +233,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, 
usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService) + scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService) application := &Application{ Server: httpServer, Cleanup: v, @@ -237,6 +252,10 @@ type Application struct { Cleanup func() } +func providePrivacyClientFactory() service.PrivacyClientFactory { + return repository.CreatePrivacyReqClient +} + func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { return service.BuildInfo{ Version: buildInfo.Version, @@ -270,6 +289,8 @@ func provideCleanup( geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, openAIGateway *service.OpenAIGatewayService, + scheduledTestRunner *service.ScheduledTestRunnerService, + backupSvc *service.BackupService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -399,6 +420,18 @@ func provideCleanup( } return nil }}, + {"ScheduledTestRunnerService", func() error { + if scheduledTestRunner != nil { + scheduledTestRunner.Stop() + } + return nil + }}, + {"BackupService", func() error { + if backupSvc != nil { + backupSvc.Stop() + } + return nil + }}, } infraSteps := []cleanupStep{ diff --git 
a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index 9fb9888d..9d2a54b9 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -37,12 +37,13 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { nil, nil, cfg, + nil, ) accountExpirySvc := service.NewAccountExpiryService(nil, time.Second) subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) pricingSvc := service.NewPricingService(cfg, nil) emailQueueSvc := service.NewEmailQueueService(nil, 1) - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) @@ -73,6 +74,8 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { geminiOAuthSvc, antigravityOAuthSvc, nil, // openAIGateway + nil, // scheduledTestRunner + nil, // backupSvc ) require.NotPanics(t, func() { diff --git a/backend/ent/account.go b/backend/ent/account.go index c77002b3..2dbfc3a2 100644 --- a/backend/ent/account.go +++ b/backend/ent/account.go @@ -41,6 +41,8 @@ type Account struct { ProxyID *int64 `json:"proxy_id,omitempty"` // Concurrency holds the value of the "concurrency" field. Concurrency int `json:"concurrency,omitempty"` + // LoadFactor holds the value of the "load_factor" field. + LoadFactor *int `json:"load_factor,omitempty"` // Priority holds the value of the "priority" field. Priority int `json:"priority,omitempty"` // RateMultiplier holds the value of the "rate_multiplier" field. 
@@ -143,7 +145,7 @@ func (*Account) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case account.FieldRateMultiplier: values[i] = new(sql.NullFloat64) - case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority: + case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldLoadFactor, account.FieldPriority: values[i] = new(sql.NullInt64) case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus: values[i] = new(sql.NullString) @@ -243,6 +245,13 @@ func (_m *Account) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Concurrency = int(value.Int64) } + case account.FieldLoadFactor: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field load_factor", values[i]) + } else if value.Valid { + _m.LoadFactor = new(int) + *_m.LoadFactor = int(value.Int64) + } case account.FieldPriority: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field priority", values[i]) @@ -445,6 +454,11 @@ func (_m *Account) String() string { builder.WriteString("concurrency=") builder.WriteString(fmt.Sprintf("%v", _m.Concurrency)) builder.WriteString(", ") + if v := _m.LoadFactor; v != nil { + builder.WriteString("load_factor=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") builder.WriteString("priority=") builder.WriteString(fmt.Sprintf("%v", _m.Priority)) builder.WriteString(", ") diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go index 1fc34620..4c134649 100644 --- a/backend/ent/account/account.go +++ b/backend/ent/account/account.go @@ -37,6 +37,8 @@ const ( FieldProxyID = "proxy_id" // FieldConcurrency holds the string denoting the concurrency field in the database. 
FieldConcurrency = "concurrency" + // FieldLoadFactor holds the string denoting the load_factor field in the database. + FieldLoadFactor = "load_factor" // FieldPriority holds the string denoting the priority field in the database. FieldPriority = "priority" // FieldRateMultiplier holds the string denoting the rate_multiplier field in the database. @@ -121,6 +123,7 @@ var Columns = []string{ FieldExtra, FieldProxyID, FieldConcurrency, + FieldLoadFactor, FieldPriority, FieldRateMultiplier, FieldStatus, @@ -250,6 +253,11 @@ func ByConcurrency(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldConcurrency, opts...).ToFunc() } +// ByLoadFactor orders the results by the load_factor field. +func ByLoadFactor(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLoadFactor, opts...).ToFunc() +} + // ByPriority orders the results by the priority field. func ByPriority(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldPriority, opts...).ToFunc() diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go index 54db1dcb..3749b45c 100644 --- a/backend/ent/account/where.go +++ b/backend/ent/account/where.go @@ -100,6 +100,11 @@ func Concurrency(v int) predicate.Account { return predicate.Account(sql.FieldEQ(FieldConcurrency, v)) } +// LoadFactor applies equality check predicate on the "load_factor" field. It's identical to LoadFactorEQ. +func LoadFactor(v int) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldLoadFactor, v)) +} + // Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ. func Priority(v int) predicate.Account { return predicate.Account(sql.FieldEQ(FieldPriority, v)) @@ -650,6 +655,56 @@ func ConcurrencyLTE(v int) predicate.Account { return predicate.Account(sql.FieldLTE(FieldConcurrency, v)) } +// LoadFactorEQ applies the EQ predicate on the "load_factor" field. 
+func LoadFactorEQ(v int) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldLoadFactor, v)) +} + +// LoadFactorNEQ applies the NEQ predicate on the "load_factor" field. +func LoadFactorNEQ(v int) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldLoadFactor, v)) +} + +// LoadFactorIn applies the In predicate on the "load_factor" field. +func LoadFactorIn(vs ...int) predicate.Account { + return predicate.Account(sql.FieldIn(FieldLoadFactor, vs...)) +} + +// LoadFactorNotIn applies the NotIn predicate on the "load_factor" field. +func LoadFactorNotIn(vs ...int) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldLoadFactor, vs...)) +} + +// LoadFactorGT applies the GT predicate on the "load_factor" field. +func LoadFactorGT(v int) predicate.Account { + return predicate.Account(sql.FieldGT(FieldLoadFactor, v)) +} + +// LoadFactorGTE applies the GTE predicate on the "load_factor" field. +func LoadFactorGTE(v int) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldLoadFactor, v)) +} + +// LoadFactorLT applies the LT predicate on the "load_factor" field. +func LoadFactorLT(v int) predicate.Account { + return predicate.Account(sql.FieldLT(FieldLoadFactor, v)) +} + +// LoadFactorLTE applies the LTE predicate on the "load_factor" field. +func LoadFactorLTE(v int) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldLoadFactor, v)) +} + +// LoadFactorIsNil applies the IsNil predicate on the "load_factor" field. +func LoadFactorIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldLoadFactor)) +} + +// LoadFactorNotNil applies the NotNil predicate on the "load_factor" field. +func LoadFactorNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldLoadFactor)) +} + // PriorityEQ applies the EQ predicate on the "priority" field. 
func PriorityEQ(v int) predicate.Account { return predicate.Account(sql.FieldEQ(FieldPriority, v)) diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go index 963ffee8..d6046c79 100644 --- a/backend/ent/account_create.go +++ b/backend/ent/account_create.go @@ -139,6 +139,20 @@ func (_c *AccountCreate) SetNillableConcurrency(v *int) *AccountCreate { return _c } +// SetLoadFactor sets the "load_factor" field. +func (_c *AccountCreate) SetLoadFactor(v int) *AccountCreate { + _c.mutation.SetLoadFactor(v) + return _c +} + +// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil. +func (_c *AccountCreate) SetNillableLoadFactor(v *int) *AccountCreate { + if v != nil { + _c.SetLoadFactor(*v) + } + return _c +} + // SetPriority sets the "priority" field. func (_c *AccountCreate) SetPriority(v int) *AccountCreate { _c.mutation.SetPriority(v) @@ -623,6 +637,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { _spec.SetField(account.FieldConcurrency, field.TypeInt, value) _node.Concurrency = value } + if value, ok := _c.mutation.LoadFactor(); ok { + _spec.SetField(account.FieldLoadFactor, field.TypeInt, value) + _node.LoadFactor = &value + } if value, ok := _c.mutation.Priority(); ok { _spec.SetField(account.FieldPriority, field.TypeInt, value) _node.Priority = value @@ -936,6 +954,30 @@ func (u *AccountUpsert) AddConcurrency(v int) *AccountUpsert { return u } +// SetLoadFactor sets the "load_factor" field. +func (u *AccountUpsert) SetLoadFactor(v int) *AccountUpsert { + u.Set(account.FieldLoadFactor, v) + return u +} + +// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create. +func (u *AccountUpsert) UpdateLoadFactor() *AccountUpsert { + u.SetExcluded(account.FieldLoadFactor) + return u +} + +// AddLoadFactor adds v to the "load_factor" field. 
+func (u *AccountUpsert) AddLoadFactor(v int) *AccountUpsert { + u.Add(account.FieldLoadFactor, v) + return u +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (u *AccountUpsert) ClearLoadFactor() *AccountUpsert { + u.SetNull(account.FieldLoadFactor) + return u +} + // SetPriority sets the "priority" field. func (u *AccountUpsert) SetPriority(v int) *AccountUpsert { u.Set(account.FieldPriority, v) @@ -1419,6 +1461,34 @@ func (u *AccountUpsertOne) UpdateConcurrency() *AccountUpsertOne { }) } +// SetLoadFactor sets the "load_factor" field. +func (u *AccountUpsertOne) SetLoadFactor(v int) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetLoadFactor(v) + }) +} + +// AddLoadFactor adds v to the "load_factor" field. +func (u *AccountUpsertOne) AddLoadFactor(v int) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.AddLoadFactor(v) + }) +} + +// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create. +func (u *AccountUpsertOne) UpdateLoadFactor() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateLoadFactor() + }) +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (u *AccountUpsertOne) ClearLoadFactor() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearLoadFactor() + }) +} + // SetPriority sets the "priority" field. func (u *AccountUpsertOne) SetPriority(v int) *AccountUpsertOne { return u.Update(func(s *AccountUpsert) { @@ -2113,6 +2183,34 @@ func (u *AccountUpsertBulk) UpdateConcurrency() *AccountUpsertBulk { }) } +// SetLoadFactor sets the "load_factor" field. +func (u *AccountUpsertBulk) SetLoadFactor(v int) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetLoadFactor(v) + }) +} + +// AddLoadFactor adds v to the "load_factor" field. 
+func (u *AccountUpsertBulk) AddLoadFactor(v int) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.AddLoadFactor(v) + }) +} + +// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create. +func (u *AccountUpsertBulk) UpdateLoadFactor() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateLoadFactor() + }) +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (u *AccountUpsertBulk) ClearLoadFactor() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearLoadFactor() + }) +} + // SetPriority sets the "priority" field. func (u *AccountUpsertBulk) SetPriority(v int) *AccountUpsertBulk { return u.Update(func(s *AccountUpsert) { diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go index 875888e0..6f443c65 100644 --- a/backend/ent/account_update.go +++ b/backend/ent/account_update.go @@ -172,6 +172,33 @@ func (_u *AccountUpdate) AddConcurrency(v int) *AccountUpdate { return _u } +// SetLoadFactor sets the "load_factor" field. +func (_u *AccountUpdate) SetLoadFactor(v int) *AccountUpdate { + _u.mutation.ResetLoadFactor() + _u.mutation.SetLoadFactor(v) + return _u +} + +// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableLoadFactor(v *int) *AccountUpdate { + if v != nil { + _u.SetLoadFactor(*v) + } + return _u +} + +// AddLoadFactor adds value to the "load_factor" field. +func (_u *AccountUpdate) AddLoadFactor(v int) *AccountUpdate { + _u.mutation.AddLoadFactor(v) + return _u +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (_u *AccountUpdate) ClearLoadFactor() *AccountUpdate { + _u.mutation.ClearLoadFactor() + return _u +} + // SetPriority sets the "priority" field. 
func (_u *AccountUpdate) SetPriority(v int) *AccountUpdate { _u.mutation.ResetPriority() @@ -684,6 +711,15 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AddedConcurrency(); ok { _spec.AddField(account.FieldConcurrency, field.TypeInt, value) } + if value, ok := _u.mutation.LoadFactor(); ok { + _spec.SetField(account.FieldLoadFactor, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedLoadFactor(); ok { + _spec.AddField(account.FieldLoadFactor, field.TypeInt, value) + } + if _u.mutation.LoadFactorCleared() { + _spec.ClearField(account.FieldLoadFactor, field.TypeInt) + } if value, ok := _u.mutation.Priority(); ok { _spec.SetField(account.FieldPriority, field.TypeInt, value) } @@ -1063,6 +1099,33 @@ func (_u *AccountUpdateOne) AddConcurrency(v int) *AccountUpdateOne { return _u } +// SetLoadFactor sets the "load_factor" field. +func (_u *AccountUpdateOne) SetLoadFactor(v int) *AccountUpdateOne { + _u.mutation.ResetLoadFactor() + _u.mutation.SetLoadFactor(v) + return _u +} + +// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableLoadFactor(v *int) *AccountUpdateOne { + if v != nil { + _u.SetLoadFactor(*v) + } + return _u +} + +// AddLoadFactor adds value to the "load_factor" field. +func (_u *AccountUpdateOne) AddLoadFactor(v int) *AccountUpdateOne { + _u.mutation.AddLoadFactor(v) + return _u +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (_u *AccountUpdateOne) ClearLoadFactor() *AccountUpdateOne { + _u.mutation.ClearLoadFactor() + return _u +} + // SetPriority sets the "priority" field. 
func (_u *AccountUpdateOne) SetPriority(v int) *AccountUpdateOne { _u.mutation.ResetPriority() @@ -1605,6 +1668,15 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er if value, ok := _u.mutation.AddedConcurrency(); ok { _spec.AddField(account.FieldConcurrency, field.TypeInt, value) } + if value, ok := _u.mutation.LoadFactor(); ok { + _spec.SetField(account.FieldLoadFactor, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedLoadFactor(); ok { + _spec.AddField(account.FieldLoadFactor, field.TypeInt, value) + } + if _u.mutation.LoadFactorCleared() { + _spec.ClearField(account.FieldLoadFactor, field.TypeInt) + } if value, ok := _u.mutation.Priority(); ok { _spec.SetField(account.FieldPriority, field.TypeInt, value) } diff --git a/backend/ent/announcement.go b/backend/ent/announcement.go index 93d7a375..6c5b21da 100644 --- a/backend/ent/announcement.go +++ b/backend/ent/announcement.go @@ -25,6 +25,8 @@ type Announcement struct { Content string `json:"content,omitempty"` // 状态: draft, active, archived Status string `json:"status,omitempty"` + // 通知模式: silent(仅铃铛), popup(弹窗提醒) + NotifyMode string `json:"notify_mode,omitempty"` // 展示条件(JSON 规则) Targeting domain.AnnouncementTargeting `json:"targeting,omitempty"` // 开始展示时间(为空表示立即生效) @@ -72,7 +74,7 @@ func (*Announcement) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case announcement.FieldID, announcement.FieldCreatedBy, announcement.FieldUpdatedBy: values[i] = new(sql.NullInt64) - case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus: + case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus, announcement.FieldNotifyMode: values[i] = new(sql.NullString) case announcement.FieldStartsAt, announcement.FieldEndsAt, announcement.FieldCreatedAt, announcement.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -115,6 +117,12 @@ func (_m *Announcement) assignValues(columns []string, values []any) error { } else if 
value.Valid { _m.Status = value.String } + case announcement.FieldNotifyMode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field notify_mode", values[i]) + } else if value.Valid { + _m.NotifyMode = value.String + } case announcement.FieldTargeting: if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field targeting", values[i]) @@ -213,6 +221,9 @@ func (_m *Announcement) String() string { builder.WriteString("status=") builder.WriteString(_m.Status) builder.WriteString(", ") + builder.WriteString("notify_mode=") + builder.WriteString(_m.NotifyMode) + builder.WriteString(", ") builder.WriteString("targeting=") builder.WriteString(fmt.Sprintf("%v", _m.Targeting)) builder.WriteString(", ") diff --git a/backend/ent/announcement/announcement.go b/backend/ent/announcement/announcement.go index 4f34ee05..71ba25ff 100644 --- a/backend/ent/announcement/announcement.go +++ b/backend/ent/announcement/announcement.go @@ -20,6 +20,8 @@ const ( FieldContent = "content" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" + // FieldNotifyMode holds the string denoting the notify_mode field in the database. + FieldNotifyMode = "notify_mode" // FieldTargeting holds the string denoting the targeting field in the database. FieldTargeting = "targeting" // FieldStartsAt holds the string denoting the starts_at field in the database. @@ -53,6 +55,7 @@ var Columns = []string{ FieldTitle, FieldContent, FieldStatus, + FieldNotifyMode, FieldTargeting, FieldStartsAt, FieldEndsAt, @@ -81,6 +84,10 @@ var ( DefaultStatus string // StatusValidator is a validator for the "status" field. It is called by the builders before save. StatusValidator func(string) error + // DefaultNotifyMode holds the default value on creation for the "notify_mode" field. + DefaultNotifyMode string + // NotifyModeValidator is a validator for the "notify_mode" field. 
It is called by the builders before save. + NotifyModeValidator func(string) error // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. @@ -112,6 +119,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() } +// ByNotifyMode orders the results by the notify_mode field. +func ByNotifyMode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldNotifyMode, opts...).ToFunc() +} + // ByStartsAt orders the results by the starts_at field. func ByStartsAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStartsAt, opts...).ToFunc() diff --git a/backend/ent/announcement/where.go b/backend/ent/announcement/where.go index d3cad2a5..2eea5f0b 100644 --- a/backend/ent/announcement/where.go +++ b/backend/ent/announcement/where.go @@ -70,6 +70,11 @@ func Status(v string) predicate.Announcement { return predicate.Announcement(sql.FieldEQ(FieldStatus, v)) } +// NotifyMode applies equality check predicate on the "notify_mode" field. It's identical to NotifyModeEQ. +func NotifyMode(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v)) +} + // StartsAt applies equality check predicate on the "starts_at" field. It's identical to StartsAtEQ. func StartsAt(v time.Time) predicate.Announcement { return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v)) @@ -295,6 +300,71 @@ func StatusContainsFold(v string) predicate.Announcement { return predicate.Announcement(sql.FieldContainsFold(FieldStatus, v)) } +// NotifyModeEQ applies the EQ predicate on the "notify_mode" field. +func NotifyModeEQ(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v)) +} + +// NotifyModeNEQ applies the NEQ predicate on the "notify_mode" field. 
+func NotifyModeNEQ(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldNEQ(FieldNotifyMode, v)) +} + +// NotifyModeIn applies the In predicate on the "notify_mode" field. +func NotifyModeIn(vs ...string) predicate.Announcement { + return predicate.Announcement(sql.FieldIn(FieldNotifyMode, vs...)) +} + +// NotifyModeNotIn applies the NotIn predicate on the "notify_mode" field. +func NotifyModeNotIn(vs ...string) predicate.Announcement { + return predicate.Announcement(sql.FieldNotIn(FieldNotifyMode, vs...)) +} + +// NotifyModeGT applies the GT predicate on the "notify_mode" field. +func NotifyModeGT(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldGT(FieldNotifyMode, v)) +} + +// NotifyModeGTE applies the GTE predicate on the "notify_mode" field. +func NotifyModeGTE(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldGTE(FieldNotifyMode, v)) +} + +// NotifyModeLT applies the LT predicate on the "notify_mode" field. +func NotifyModeLT(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldLT(FieldNotifyMode, v)) +} + +// NotifyModeLTE applies the LTE predicate on the "notify_mode" field. +func NotifyModeLTE(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldLTE(FieldNotifyMode, v)) +} + +// NotifyModeContains applies the Contains predicate on the "notify_mode" field. +func NotifyModeContains(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldContains(FieldNotifyMode, v)) +} + +// NotifyModeHasPrefix applies the HasPrefix predicate on the "notify_mode" field. +func NotifyModeHasPrefix(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldHasPrefix(FieldNotifyMode, v)) +} + +// NotifyModeHasSuffix applies the HasSuffix predicate on the "notify_mode" field. 
+func NotifyModeHasSuffix(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldHasSuffix(FieldNotifyMode, v)) +} + +// NotifyModeEqualFold applies the EqualFold predicate on the "notify_mode" field. +func NotifyModeEqualFold(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldEqualFold(FieldNotifyMode, v)) +} + +// NotifyModeContainsFold applies the ContainsFold predicate on the "notify_mode" field. +func NotifyModeContainsFold(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldContainsFold(FieldNotifyMode, v)) +} + // TargetingIsNil applies the IsNil predicate on the "targeting" field. func TargetingIsNil() predicate.Announcement { return predicate.Announcement(sql.FieldIsNull(FieldTargeting)) diff --git a/backend/ent/announcement_create.go b/backend/ent/announcement_create.go index 151d4c11..d9029792 100644 --- a/backend/ent/announcement_create.go +++ b/backend/ent/announcement_create.go @@ -50,6 +50,20 @@ func (_c *AnnouncementCreate) SetNillableStatus(v *string) *AnnouncementCreate { return _c } +// SetNotifyMode sets the "notify_mode" field. +func (_c *AnnouncementCreate) SetNotifyMode(v string) *AnnouncementCreate { + _c.mutation.SetNotifyMode(v) + return _c +} + +// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil. +func (_c *AnnouncementCreate) SetNillableNotifyMode(v *string) *AnnouncementCreate { + if v != nil { + _c.SetNotifyMode(*v) + } + return _c +} + // SetTargeting sets the "targeting" field. 
func (_c *AnnouncementCreate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementCreate { _c.mutation.SetTargeting(v) @@ -202,6 +216,10 @@ func (_c *AnnouncementCreate) defaults() { v := announcement.DefaultStatus _c.mutation.SetStatus(v) } + if _, ok := _c.mutation.NotifyMode(); !ok { + v := announcement.DefaultNotifyMode + _c.mutation.SetNotifyMode(v) + } if _, ok := _c.mutation.CreatedAt(); !ok { v := announcement.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) @@ -238,6 +256,14 @@ func (_c *AnnouncementCreate) check() error { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)} } } + if _, ok := _c.mutation.NotifyMode(); !ok { + return &ValidationError{Name: "notify_mode", err: errors.New(`ent: missing required field "Announcement.notify_mode"`)} + } + if v, ok := _c.mutation.NotifyMode(); ok { + if err := announcement.NotifyModeValidator(v); err != nil { + return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)} + } + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Announcement.created_at"`)} } @@ -283,6 +309,10 @@ func (_c *AnnouncementCreate) createSpec() (*Announcement, *sqlgraph.CreateSpec) _spec.SetField(announcement.FieldStatus, field.TypeString, value) _node.Status = value } + if value, ok := _c.mutation.NotifyMode(); ok { + _spec.SetField(announcement.FieldNotifyMode, field.TypeString, value) + _node.NotifyMode = value + } if value, ok := _c.mutation.Targeting(); ok { _spec.SetField(announcement.FieldTargeting, field.TypeJSON, value) _node.Targeting = value @@ -415,6 +445,18 @@ func (u *AnnouncementUpsert) UpdateStatus() *AnnouncementUpsert { return u } +// SetNotifyMode sets the "notify_mode" field. 
+func (u *AnnouncementUpsert) SetNotifyMode(v string) *AnnouncementUpsert { + u.Set(announcement.FieldNotifyMode, v) + return u +} + +// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create. +func (u *AnnouncementUpsert) UpdateNotifyMode() *AnnouncementUpsert { + u.SetExcluded(announcement.FieldNotifyMode) + return u +} + // SetTargeting sets the "targeting" field. func (u *AnnouncementUpsert) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsert { u.Set(announcement.FieldTargeting, v) @@ -616,6 +658,20 @@ func (u *AnnouncementUpsertOne) UpdateStatus() *AnnouncementUpsertOne { }) } +// SetNotifyMode sets the "notify_mode" field. +func (u *AnnouncementUpsertOne) SetNotifyMode(v string) *AnnouncementUpsertOne { + return u.Update(func(s *AnnouncementUpsert) { + s.SetNotifyMode(v) + }) +} + +// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create. +func (u *AnnouncementUpsertOne) UpdateNotifyMode() *AnnouncementUpsertOne { + return u.Update(func(s *AnnouncementUpsert) { + s.UpdateNotifyMode() + }) +} + // SetTargeting sets the "targeting" field. func (u *AnnouncementUpsertOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertOne { return u.Update(func(s *AnnouncementUpsert) { @@ -1002,6 +1058,20 @@ func (u *AnnouncementUpsertBulk) UpdateStatus() *AnnouncementUpsertBulk { }) } +// SetNotifyMode sets the "notify_mode" field. +func (u *AnnouncementUpsertBulk) SetNotifyMode(v string) *AnnouncementUpsertBulk { + return u.Update(func(s *AnnouncementUpsert) { + s.SetNotifyMode(v) + }) +} + +// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create. +func (u *AnnouncementUpsertBulk) UpdateNotifyMode() *AnnouncementUpsertBulk { + return u.Update(func(s *AnnouncementUpsert) { + s.UpdateNotifyMode() + }) +} + // SetTargeting sets the "targeting" field. 
func (u *AnnouncementUpsertBulk) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertBulk { return u.Update(func(s *AnnouncementUpsert) { diff --git a/backend/ent/announcement_update.go b/backend/ent/announcement_update.go index 702d0817..f93f4f0e 100644 --- a/backend/ent/announcement_update.go +++ b/backend/ent/announcement_update.go @@ -72,6 +72,20 @@ func (_u *AnnouncementUpdate) SetNillableStatus(v *string) *AnnouncementUpdate { return _u } +// SetNotifyMode sets the "notify_mode" field. +func (_u *AnnouncementUpdate) SetNotifyMode(v string) *AnnouncementUpdate { + _u.mutation.SetNotifyMode(v) + return _u +} + +// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil. +func (_u *AnnouncementUpdate) SetNillableNotifyMode(v *string) *AnnouncementUpdate { + if v != nil { + _u.SetNotifyMode(*v) + } + return _u +} + // SetTargeting sets the "targeting" field. func (_u *AnnouncementUpdate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdate { _u.mutation.SetTargeting(v) @@ -286,6 +300,11 @@ func (_u *AnnouncementUpdate) check() error { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)} } } + if v, ok := _u.mutation.NotifyMode(); ok { + if err := announcement.NotifyModeValidator(v); err != nil { + return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)} + } + } return nil } @@ -310,6 +329,9 @@ func (_u *AnnouncementUpdate) sqlSave(ctx context.Context) (_node int, err error if value, ok := _u.mutation.Status(); ok { _spec.SetField(announcement.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.NotifyMode(); ok { + _spec.SetField(announcement.FieldNotifyMode, field.TypeString, value) + } if value, ok := _u.mutation.Targeting(); ok { _spec.SetField(announcement.FieldTargeting, field.TypeJSON, value) } @@ -456,6 +478,20 @@ func (_u *AnnouncementUpdateOne) 
SetNillableStatus(v *string) *AnnouncementUpdat return _u } +// SetNotifyMode sets the "notify_mode" field. +func (_u *AnnouncementUpdateOne) SetNotifyMode(v string) *AnnouncementUpdateOne { + _u.mutation.SetNotifyMode(v) + return _u +} + +// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil. +func (_u *AnnouncementUpdateOne) SetNillableNotifyMode(v *string) *AnnouncementUpdateOne { + if v != nil { + _u.SetNotifyMode(*v) + } + return _u +} + // SetTargeting sets the "targeting" field. func (_u *AnnouncementUpdateOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdateOne { _u.mutation.SetTargeting(v) @@ -683,6 +719,11 @@ func (_u *AnnouncementUpdateOne) check() error { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)} } } + if v, ok := _u.mutation.NotifyMode(); ok { + if err := announcement.NotifyModeValidator(v); err != nil { + return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)} + } + } return nil } @@ -724,6 +765,9 @@ func (_u *AnnouncementUpdateOne) sqlSave(ctx context.Context) (_node *Announceme if value, ok := _u.mutation.Status(); ok { _spec.SetField(announcement.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.NotifyMode(); ok { + _spec.SetField(announcement.FieldNotifyMode, field.TypeString, value) + } if value, ok := _u.mutation.Targeting(); ok { _spec.SetField(announcement.FieldTargeting, field.TypeJSON, value) } diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index 760851c8..9ee660c2 100644 --- a/backend/ent/apikey.go +++ b/backend/ent/apikey.go @@ -48,6 +48,24 @@ type APIKey struct { QuotaUsed float64 `json:"quota_used,omitempty"` // Expiration time for this API key (null = never expires) ExpiresAt *time.Time `json:"expires_at,omitempty"` + // Rate limit in USD per 5 hours (0 = unlimited) + RateLimit5h float64 
`json:"rate_limit_5h,omitempty"` + // Rate limit in USD per day (0 = unlimited) + RateLimit1d float64 `json:"rate_limit_1d,omitempty"` + // Rate limit in USD per 7 days (0 = unlimited) + RateLimit7d float64 `json:"rate_limit_7d,omitempty"` + // Used amount in USD for the current 5h window + Usage5h float64 `json:"usage_5h,omitempty"` + // Used amount in USD for the current 1d window + Usage1d float64 `json:"usage_1d,omitempty"` + // Used amount in USD for the current 7d window + Usage7d float64 `json:"usage_7d,omitempty"` + // Start time of the current 5h rate limit window + Window5hStart *time.Time `json:"window_5h_start,omitempty"` + // Start time of the current 1d rate limit window + Window1dStart *time.Time `json:"window_1d_start,omitempty"` + // Start time of the current 7d rate limit window + Window7dStart *time.Time `json:"window_7d_start,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the APIKeyQuery when eager-loading is set. 
Edges APIKeyEdges `json:"edges"` @@ -105,13 +123,13 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { switch columns[i] { case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist: values[i] = new([]byte) - case apikey.FieldQuota, apikey.FieldQuotaUsed: + case apikey.FieldQuota, apikey.FieldQuotaUsed, apikey.FieldRateLimit5h, apikey.FieldRateLimit1d, apikey.FieldRateLimit7d, apikey.FieldUsage5h, apikey.FieldUsage1d, apikey.FieldUsage7d: values[i] = new(sql.NullFloat64) case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: values[i] = new(sql.NullInt64) case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: values[i] = new(sql.NullString) - case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt: + case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt, apikey.FieldWindow5hStart, apikey.FieldWindow1dStart, apikey.FieldWindow7dStart: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -226,6 +244,63 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { _m.ExpiresAt = new(time.Time) *_m.ExpiresAt = value.Time } + case apikey.FieldRateLimit5h: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field rate_limit_5h", values[i]) + } else if value.Valid { + _m.RateLimit5h = value.Float64 + } + case apikey.FieldRateLimit1d: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field rate_limit_1d", values[i]) + } else if value.Valid { + _m.RateLimit1d = value.Float64 + } + case apikey.FieldRateLimit7d: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field rate_limit_7d", values[i]) + } else if value.Valid { + _m.RateLimit7d = value.Float64 + } + case apikey.FieldUsage5h: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return 
fmt.Errorf("unexpected type %T for field usage_5h", values[i]) + } else if value.Valid { + _m.Usage5h = value.Float64 + } + case apikey.FieldUsage1d: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field usage_1d", values[i]) + } else if value.Valid { + _m.Usage1d = value.Float64 + } + case apikey.FieldUsage7d: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field usage_7d", values[i]) + } else if value.Valid { + _m.Usage7d = value.Float64 + } + case apikey.FieldWindow5hStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field window_5h_start", values[i]) + } else if value.Valid { + _m.Window5hStart = new(time.Time) + *_m.Window5hStart = value.Time + } + case apikey.FieldWindow1dStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field window_1d_start", values[i]) + } else if value.Valid { + _m.Window1dStart = new(time.Time) + *_m.Window1dStart = value.Time + } + case apikey.FieldWindow7dStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field window_7d_start", values[i]) + } else if value.Valid { + _m.Window7dStart = new(time.Time) + *_m.Window7dStart = value.Time + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -326,6 +401,39 @@ func (_m *APIKey) String() string { builder.WriteString("expires_at=") builder.WriteString(v.Format(time.ANSIC)) } + builder.WriteString(", ") + builder.WriteString("rate_limit_5h=") + builder.WriteString(fmt.Sprintf("%v", _m.RateLimit5h)) + builder.WriteString(", ") + builder.WriteString("rate_limit_1d=") + builder.WriteString(fmt.Sprintf("%v", _m.RateLimit1d)) + builder.WriteString(", ") + builder.WriteString("rate_limit_7d=") + builder.WriteString(fmt.Sprintf("%v", _m.RateLimit7d)) + builder.WriteString(", ") + builder.WriteString("usage_5h=") + 
builder.WriteString(fmt.Sprintf("%v", _m.Usage5h)) + builder.WriteString(", ") + builder.WriteString("usage_1d=") + builder.WriteString(fmt.Sprintf("%v", _m.Usage1d)) + builder.WriteString(", ") + builder.WriteString("usage_7d=") + builder.WriteString(fmt.Sprintf("%v", _m.Usage7d)) + builder.WriteString(", ") + if v := _m.Window5hStart; v != nil { + builder.WriteString("window_5h_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.Window1dStart; v != nil { + builder.WriteString("window_1d_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.Window7dStart; v != nil { + builder.WriteString("window_7d_start=") + builder.WriteString(v.Format(time.ANSIC)) + } builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index 6abea56b..d398a027 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -43,6 +43,24 @@ const ( FieldQuotaUsed = "quota_used" // FieldExpiresAt holds the string denoting the expires_at field in the database. FieldExpiresAt = "expires_at" + // FieldRateLimit5h holds the string denoting the rate_limit_5h field in the database. + FieldRateLimit5h = "rate_limit_5h" + // FieldRateLimit1d holds the string denoting the rate_limit_1d field in the database. + FieldRateLimit1d = "rate_limit_1d" + // FieldRateLimit7d holds the string denoting the rate_limit_7d field in the database. + FieldRateLimit7d = "rate_limit_7d" + // FieldUsage5h holds the string denoting the usage_5h field in the database. + FieldUsage5h = "usage_5h" + // FieldUsage1d holds the string denoting the usage_1d field in the database. + FieldUsage1d = "usage_1d" + // FieldUsage7d holds the string denoting the usage_7d field in the database. + FieldUsage7d = "usage_7d" + // FieldWindow5hStart holds the string denoting the window_5h_start field in the database. 
+ FieldWindow5hStart = "window_5h_start" + // FieldWindow1dStart holds the string denoting the window_1d_start field in the database. + FieldWindow1dStart = "window_1d_start" + // FieldWindow7dStart holds the string denoting the window_7d_start field in the database. + FieldWindow7dStart = "window_7d_start" // EdgeUser holds the string denoting the user edge name in mutations. EdgeUser = "user" // EdgeGroup holds the string denoting the group edge name in mutations. @@ -91,6 +109,15 @@ var Columns = []string{ FieldQuota, FieldQuotaUsed, FieldExpiresAt, + FieldRateLimit5h, + FieldRateLimit1d, + FieldRateLimit7d, + FieldUsage5h, + FieldUsage1d, + FieldUsage7d, + FieldWindow5hStart, + FieldWindow1dStart, + FieldWindow7dStart, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -129,6 +156,18 @@ var ( DefaultQuota float64 // DefaultQuotaUsed holds the default value on creation for the "quota_used" field. DefaultQuotaUsed float64 + // DefaultRateLimit5h holds the default value on creation for the "rate_limit_5h" field. + DefaultRateLimit5h float64 + // DefaultRateLimit1d holds the default value on creation for the "rate_limit_1d" field. + DefaultRateLimit1d float64 + // DefaultRateLimit7d holds the default value on creation for the "rate_limit_7d" field. + DefaultRateLimit7d float64 + // DefaultUsage5h holds the default value on creation for the "usage_5h" field. + DefaultUsage5h float64 + // DefaultUsage1d holds the default value on creation for the "usage_1d" field. + DefaultUsage1d float64 + // DefaultUsage7d holds the default value on creation for the "usage_7d" field. + DefaultUsage7d float64 ) // OrderOption defines the ordering options for the APIKey queries. @@ -199,6 +238,51 @@ func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() } +// ByRateLimit5h orders the results by the rate_limit_5h field. 
+func ByRateLimit5h(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRateLimit5h, opts...).ToFunc() +} + +// ByRateLimit1d orders the results by the rate_limit_1d field. +func ByRateLimit1d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRateLimit1d, opts...).ToFunc() +} + +// ByRateLimit7d orders the results by the rate_limit_7d field. +func ByRateLimit7d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRateLimit7d, opts...).ToFunc() +} + +// ByUsage5h orders the results by the usage_5h field. +func ByUsage5h(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsage5h, opts...).ToFunc() +} + +// ByUsage1d orders the results by the usage_1d field. +func ByUsage1d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsage1d, opts...).ToFunc() +} + +// ByUsage7d orders the results by the usage_7d field. +func ByUsage7d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsage7d, opts...).ToFunc() +} + +// ByWindow5hStart orders the results by the window_5h_start field. +func ByWindow5hStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWindow5hStart, opts...).ToFunc() +} + +// ByWindow1dStart orders the results by the window_1d_start field. +func ByWindow1dStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWindow1dStart, opts...).ToFunc() +} + +// ByWindow7dStart orders the results by the window_7d_start field. +func ByWindow7dStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWindow7dStart, opts...).ToFunc() +} + // ByUserField orders the results by user field. 
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index c1900ee1..edd2652b 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -115,6 +115,51 @@ func ExpiresAt(v time.Time) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v)) } +// RateLimit5h applies equality check predicate on the "rate_limit_5h" field. It's identical to RateLimit5hEQ. +func RateLimit5h(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit5h, v)) +} + +// RateLimit1d applies equality check predicate on the "rate_limit_1d" field. It's identical to RateLimit1dEQ. +func RateLimit1d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit1d, v)) +} + +// RateLimit7d applies equality check predicate on the "rate_limit_7d" field. It's identical to RateLimit7dEQ. +func RateLimit7d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit7d, v)) +} + +// Usage5h applies equality check predicate on the "usage_5h" field. It's identical to Usage5hEQ. +func Usage5h(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage5h, v)) +} + +// Usage1d applies equality check predicate on the "usage_1d" field. It's identical to Usage1dEQ. +func Usage1d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage1d, v)) +} + +// Usage7d applies equality check predicate on the "usage_7d" field. It's identical to Usage7dEQ. +func Usage7d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage7d, v)) +} + +// Window5hStart applies equality check predicate on the "window_5h_start" field. It's identical to Window5hStartEQ. 
+func Window5hStart(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow5hStart, v)) +} + +// Window1dStart applies equality check predicate on the "window_1d_start" field. It's identical to Window1dStartEQ. +func Window1dStart(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow1dStart, v)) +} + +// Window7dStart applies equality check predicate on the "window_7d_start" field. It's identical to Window7dStartEQ. +func Window7dStart(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow7dStart, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v)) @@ -690,6 +735,396 @@ func ExpiresAtNotNil() predicate.APIKey { return predicate.APIKey(sql.FieldNotNull(FieldExpiresAt)) } +// RateLimit5hEQ applies the EQ predicate on the "rate_limit_5h" field. +func RateLimit5hEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit5h, v)) +} + +// RateLimit5hNEQ applies the NEQ predicate on the "rate_limit_5h" field. +func RateLimit5hNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldRateLimit5h, v)) +} + +// RateLimit5hIn applies the In predicate on the "rate_limit_5h" field. +func RateLimit5hIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldRateLimit5h, vs...)) +} + +// RateLimit5hNotIn applies the NotIn predicate on the "rate_limit_5h" field. +func RateLimit5hNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldRateLimit5h, vs...)) +} + +// RateLimit5hGT applies the GT predicate on the "rate_limit_5h" field. +func RateLimit5hGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldRateLimit5h, v)) +} + +// RateLimit5hGTE applies the GTE predicate on the "rate_limit_5h" field. 
+func RateLimit5hGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldRateLimit5h, v)) +} + +// RateLimit5hLT applies the LT predicate on the "rate_limit_5h" field. +func RateLimit5hLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldRateLimit5h, v)) +} + +// RateLimit5hLTE applies the LTE predicate on the "rate_limit_5h" field. +func RateLimit5hLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldRateLimit5h, v)) +} + +// RateLimit1dEQ applies the EQ predicate on the "rate_limit_1d" field. +func RateLimit1dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit1d, v)) +} + +// RateLimit1dNEQ applies the NEQ predicate on the "rate_limit_1d" field. +func RateLimit1dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldRateLimit1d, v)) +} + +// RateLimit1dIn applies the In predicate on the "rate_limit_1d" field. +func RateLimit1dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldRateLimit1d, vs...)) +} + +// RateLimit1dNotIn applies the NotIn predicate on the "rate_limit_1d" field. +func RateLimit1dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldRateLimit1d, vs...)) +} + +// RateLimit1dGT applies the GT predicate on the "rate_limit_1d" field. +func RateLimit1dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldRateLimit1d, v)) +} + +// RateLimit1dGTE applies the GTE predicate on the "rate_limit_1d" field. +func RateLimit1dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldRateLimit1d, v)) +} + +// RateLimit1dLT applies the LT predicate on the "rate_limit_1d" field. +func RateLimit1dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldRateLimit1d, v)) +} + +// RateLimit1dLTE applies the LTE predicate on the "rate_limit_1d" field. 
+func RateLimit1dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldRateLimit1d, v)) +} + +// RateLimit7dEQ applies the EQ predicate on the "rate_limit_7d" field. +func RateLimit7dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit7d, v)) +} + +// RateLimit7dNEQ applies the NEQ predicate on the "rate_limit_7d" field. +func RateLimit7dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldRateLimit7d, v)) +} + +// RateLimit7dIn applies the In predicate on the "rate_limit_7d" field. +func RateLimit7dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldRateLimit7d, vs...)) +} + +// RateLimit7dNotIn applies the NotIn predicate on the "rate_limit_7d" field. +func RateLimit7dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldRateLimit7d, vs...)) +} + +// RateLimit7dGT applies the GT predicate on the "rate_limit_7d" field. +func RateLimit7dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldRateLimit7d, v)) +} + +// RateLimit7dGTE applies the GTE predicate on the "rate_limit_7d" field. +func RateLimit7dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldRateLimit7d, v)) +} + +// RateLimit7dLT applies the LT predicate on the "rate_limit_7d" field. +func RateLimit7dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldRateLimit7d, v)) +} + +// RateLimit7dLTE applies the LTE predicate on the "rate_limit_7d" field. +func RateLimit7dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldRateLimit7d, v)) +} + +// Usage5hEQ applies the EQ predicate on the "usage_5h" field. +func Usage5hEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage5h, v)) +} + +// Usage5hNEQ applies the NEQ predicate on the "usage_5h" field. 
+func Usage5hNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldUsage5h, v)) +} + +// Usage5hIn applies the In predicate on the "usage_5h" field. +func Usage5hIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldUsage5h, vs...)) +} + +// Usage5hNotIn applies the NotIn predicate on the "usage_5h" field. +func Usage5hNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldUsage5h, vs...)) +} + +// Usage5hGT applies the GT predicate on the "usage_5h" field. +func Usage5hGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldUsage5h, v)) +} + +// Usage5hGTE applies the GTE predicate on the "usage_5h" field. +func Usage5hGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldUsage5h, v)) +} + +// Usage5hLT applies the LT predicate on the "usage_5h" field. +func Usage5hLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldUsage5h, v)) +} + +// Usage5hLTE applies the LTE predicate on the "usage_5h" field. +func Usage5hLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldUsage5h, v)) +} + +// Usage1dEQ applies the EQ predicate on the "usage_1d" field. +func Usage1dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage1d, v)) +} + +// Usage1dNEQ applies the NEQ predicate on the "usage_1d" field. +func Usage1dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldUsage1d, v)) +} + +// Usage1dIn applies the In predicate on the "usage_1d" field. +func Usage1dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldUsage1d, vs...)) +} + +// Usage1dNotIn applies the NotIn predicate on the "usage_1d" field. +func Usage1dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldUsage1d, vs...)) +} + +// Usage1dGT applies the GT predicate on the "usage_1d" field. 
+func Usage1dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldUsage1d, v)) +} + +// Usage1dGTE applies the GTE predicate on the "usage_1d" field. +func Usage1dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldUsage1d, v)) +} + +// Usage1dLT applies the LT predicate on the "usage_1d" field. +func Usage1dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldUsage1d, v)) +} + +// Usage1dLTE applies the LTE predicate on the "usage_1d" field. +func Usage1dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldUsage1d, v)) +} + +// Usage7dEQ applies the EQ predicate on the "usage_7d" field. +func Usage7dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage7d, v)) +} + +// Usage7dNEQ applies the NEQ predicate on the "usage_7d" field. +func Usage7dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldUsage7d, v)) +} + +// Usage7dIn applies the In predicate on the "usage_7d" field. +func Usage7dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldUsage7d, vs...)) +} + +// Usage7dNotIn applies the NotIn predicate on the "usage_7d" field. +func Usage7dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldUsage7d, vs...)) +} + +// Usage7dGT applies the GT predicate on the "usage_7d" field. +func Usage7dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldUsage7d, v)) +} + +// Usage7dGTE applies the GTE predicate on the "usage_7d" field. +func Usage7dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldUsage7d, v)) +} + +// Usage7dLT applies the LT predicate on the "usage_7d" field. +func Usage7dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldUsage7d, v)) +} + +// Usage7dLTE applies the LTE predicate on the "usage_7d" field. 
+func Usage7dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldUsage7d, v)) +} + +// Window5hStartEQ applies the EQ predicate on the "window_5h_start" field. +func Window5hStartEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow5hStart, v)) +} + +// Window5hStartNEQ applies the NEQ predicate on the "window_5h_start" field. +func Window5hStartNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldWindow5hStart, v)) +} + +// Window5hStartIn applies the In predicate on the "window_5h_start" field. +func Window5hStartIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldWindow5hStart, vs...)) +} + +// Window5hStartNotIn applies the NotIn predicate on the "window_5h_start" field. +func Window5hStartNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldWindow5hStart, vs...)) +} + +// Window5hStartGT applies the GT predicate on the "window_5h_start" field. +func Window5hStartGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldWindow5hStart, v)) +} + +// Window5hStartGTE applies the GTE predicate on the "window_5h_start" field. +func Window5hStartGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldWindow5hStart, v)) +} + +// Window5hStartLT applies the LT predicate on the "window_5h_start" field. +func Window5hStartLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldWindow5hStart, v)) +} + +// Window5hStartLTE applies the LTE predicate on the "window_5h_start" field. +func Window5hStartLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldWindow5hStart, v)) +} + +// Window5hStartIsNil applies the IsNil predicate on the "window_5h_start" field. 
+func Window5hStartIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldWindow5hStart)) +} + +// Window5hStartNotNil applies the NotNil predicate on the "window_5h_start" field. +func Window5hStartNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldWindow5hStart)) +} + +// Window1dStartEQ applies the EQ predicate on the "window_1d_start" field. +func Window1dStartEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow1dStart, v)) +} + +// Window1dStartNEQ applies the NEQ predicate on the "window_1d_start" field. +func Window1dStartNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldWindow1dStart, v)) +} + +// Window1dStartIn applies the In predicate on the "window_1d_start" field. +func Window1dStartIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldWindow1dStart, vs...)) +} + +// Window1dStartNotIn applies the NotIn predicate on the "window_1d_start" field. +func Window1dStartNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldWindow1dStart, vs...)) +} + +// Window1dStartGT applies the GT predicate on the "window_1d_start" field. +func Window1dStartGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldWindow1dStart, v)) +} + +// Window1dStartGTE applies the GTE predicate on the "window_1d_start" field. +func Window1dStartGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldWindow1dStart, v)) +} + +// Window1dStartLT applies the LT predicate on the "window_1d_start" field. +func Window1dStartLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldWindow1dStart, v)) +} + +// Window1dStartLTE applies the LTE predicate on the "window_1d_start" field. 
+func Window1dStartLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldWindow1dStart, v)) +} + +// Window1dStartIsNil applies the IsNil predicate on the "window_1d_start" field. +func Window1dStartIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldWindow1dStart)) +} + +// Window1dStartNotNil applies the NotNil predicate on the "window_1d_start" field. +func Window1dStartNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldWindow1dStart)) +} + +// Window7dStartEQ applies the EQ predicate on the "window_7d_start" field. +func Window7dStartEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow7dStart, v)) +} + +// Window7dStartNEQ applies the NEQ predicate on the "window_7d_start" field. +func Window7dStartNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldWindow7dStart, v)) +} + +// Window7dStartIn applies the In predicate on the "window_7d_start" field. +func Window7dStartIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldWindow7dStart, vs...)) +} + +// Window7dStartNotIn applies the NotIn predicate on the "window_7d_start" field. +func Window7dStartNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldWindow7dStart, vs...)) +} + +// Window7dStartGT applies the GT predicate on the "window_7d_start" field. +func Window7dStartGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldWindow7dStart, v)) +} + +// Window7dStartGTE applies the GTE predicate on the "window_7d_start" field. +func Window7dStartGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldWindow7dStart, v)) +} + +// Window7dStartLT applies the LT predicate on the "window_7d_start" field. 
+func Window7dStartLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldWindow7dStart, v)) +} + +// Window7dStartLTE applies the LTE predicate on the "window_7d_start" field. +func Window7dStartLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldWindow7dStart, v)) +} + +// Window7dStartIsNil applies the IsNil predicate on the "window_7d_start" field. +func Window7dStartIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldWindow7dStart)) +} + +// Window7dStartNotNil applies the NotNil predicate on the "window_7d_start" field. +func Window7dStartNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldWindow7dStart)) +} + // HasUser applies the HasEdge predicate on the "user" edge. func HasUser() predicate.APIKey { return predicate.APIKey(func(s *sql.Selector) { diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index bc506585..4ec8aeaa 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -181,6 +181,132 @@ func (_c *APIKeyCreate) SetNillableExpiresAt(v *time.Time) *APIKeyCreate { return _c } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (_c *APIKeyCreate) SetRateLimit5h(v float64) *APIKeyCreate { + _c.mutation.SetRateLimit5h(v) + return _c +} + +// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableRateLimit5h(v *float64) *APIKeyCreate { + if v != nil { + _c.SetRateLimit5h(*v) + } + return _c +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (_c *APIKeyCreate) SetRateLimit1d(v float64) *APIKeyCreate { + _c.mutation.SetRateLimit1d(v) + return _c +} + +// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableRateLimit1d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetRateLimit1d(*v) + } + return _c +} + +// SetRateLimit7d sets the "rate_limit_7d" field. 
+func (_c *APIKeyCreate) SetRateLimit7d(v float64) *APIKeyCreate { + _c.mutation.SetRateLimit7d(v) + return _c +} + +// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableRateLimit7d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetRateLimit7d(*v) + } + return _c +} + +// SetUsage5h sets the "usage_5h" field. +func (_c *APIKeyCreate) SetUsage5h(v float64) *APIKeyCreate { + _c.mutation.SetUsage5h(v) + return _c +} + +// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableUsage5h(v *float64) *APIKeyCreate { + if v != nil { + _c.SetUsage5h(*v) + } + return _c +} + +// SetUsage1d sets the "usage_1d" field. +func (_c *APIKeyCreate) SetUsage1d(v float64) *APIKeyCreate { + _c.mutation.SetUsage1d(v) + return _c +} + +// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableUsage1d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetUsage1d(*v) + } + return _c +} + +// SetUsage7d sets the "usage_7d" field. +func (_c *APIKeyCreate) SetUsage7d(v float64) *APIKeyCreate { + _c.mutation.SetUsage7d(v) + return _c +} + +// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableUsage7d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetUsage7d(*v) + } + return _c +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (_c *APIKeyCreate) SetWindow5hStart(v time.Time) *APIKeyCreate { + _c.mutation.SetWindow5hStart(v) + return _c +} + +// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableWindow5hStart(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetWindow5hStart(*v) + } + return _c +} + +// SetWindow1dStart sets the "window_1d_start" field. 
+func (_c *APIKeyCreate) SetWindow1dStart(v time.Time) *APIKeyCreate { + _c.mutation.SetWindow1dStart(v) + return _c +} + +// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableWindow1dStart(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetWindow1dStart(*v) + } + return _c +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (_c *APIKeyCreate) SetWindow7dStart(v time.Time) *APIKeyCreate { + _c.mutation.SetWindow7dStart(v) + return _c +} + +// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableWindow7dStart(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetWindow7dStart(*v) + } + return _c +} + // SetUser sets the "user" edge to the User entity. func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { return _c.SetUserID(v.ID) @@ -269,6 +395,30 @@ func (_c *APIKeyCreate) defaults() error { v := apikey.DefaultQuotaUsed _c.mutation.SetQuotaUsed(v) } + if _, ok := _c.mutation.RateLimit5h(); !ok { + v := apikey.DefaultRateLimit5h + _c.mutation.SetRateLimit5h(v) + } + if _, ok := _c.mutation.RateLimit1d(); !ok { + v := apikey.DefaultRateLimit1d + _c.mutation.SetRateLimit1d(v) + } + if _, ok := _c.mutation.RateLimit7d(); !ok { + v := apikey.DefaultRateLimit7d + _c.mutation.SetRateLimit7d(v) + } + if _, ok := _c.mutation.Usage5h(); !ok { + v := apikey.DefaultUsage5h + _c.mutation.SetUsage5h(v) + } + if _, ok := _c.mutation.Usage1d(); !ok { + v := apikey.DefaultUsage1d + _c.mutation.SetUsage1d(v) + } + if _, ok := _c.mutation.Usage7d(); !ok { + v := apikey.DefaultUsage7d + _c.mutation.SetUsage7d(v) + } return nil } @@ -313,6 +463,24 @@ func (_c *APIKeyCreate) check() error { if _, ok := _c.mutation.QuotaUsed(); !ok { return &ValidationError{Name: "quota_used", err: errors.New(`ent: missing required field "APIKey.quota_used"`)} } + if _, ok := _c.mutation.RateLimit5h(); !ok { + return &ValidationError{Name: 
"rate_limit_5h", err: errors.New(`ent: missing required field "APIKey.rate_limit_5h"`)} + } + if _, ok := _c.mutation.RateLimit1d(); !ok { + return &ValidationError{Name: "rate_limit_1d", err: errors.New(`ent: missing required field "APIKey.rate_limit_1d"`)} + } + if _, ok := _c.mutation.RateLimit7d(); !ok { + return &ValidationError{Name: "rate_limit_7d", err: errors.New(`ent: missing required field "APIKey.rate_limit_7d"`)} + } + if _, ok := _c.mutation.Usage5h(); !ok { + return &ValidationError{Name: "usage_5h", err: errors.New(`ent: missing required field "APIKey.usage_5h"`)} + } + if _, ok := _c.mutation.Usage1d(); !ok { + return &ValidationError{Name: "usage_1d", err: errors.New(`ent: missing required field "APIKey.usage_1d"`)} + } + if _, ok := _c.mutation.Usage7d(); !ok { + return &ValidationError{Name: "usage_7d", err: errors.New(`ent: missing required field "APIKey.usage_7d"`)} + } if len(_c.mutation.UserIDs()) == 0 { return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)} } @@ -391,6 +559,42 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) _node.ExpiresAt = &value } + if value, ok := _c.mutation.RateLimit5h(); ok { + _spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + _node.RateLimit5h = value + } + if value, ok := _c.mutation.RateLimit1d(); ok { + _spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + _node.RateLimit1d = value + } + if value, ok := _c.mutation.RateLimit7d(); ok { + _spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + _node.RateLimit7d = value + } + if value, ok := _c.mutation.Usage5h(); ok { + _spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value) + _node.Usage5h = value + } + if value, ok := _c.mutation.Usage1d(); ok { + _spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value) + _node.Usage1d = value + } + if value, ok := _c.mutation.Usage7d(); ok { + 
_spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value) + _node.Usage7d = value + } + if value, ok := _c.mutation.Window5hStart(); ok { + _spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value) + _node.Window5hStart = &value + } + if value, ok := _c.mutation.Window1dStart(); ok { + _spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value) + _node.Window1dStart = &value + } + if value, ok := _c.mutation.Window7dStart(); ok { + _spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value) + _node.Window7dStart = &value + } if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -697,6 +901,168 @@ func (u *APIKeyUpsert) ClearExpiresAt() *APIKeyUpsert { return u } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (u *APIKeyUpsert) SetRateLimit5h(v float64) *APIKeyUpsert { + u.Set(apikey.FieldRateLimit5h, v) + return u +} + +// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateRateLimit5h() *APIKeyUpsert { + u.SetExcluded(apikey.FieldRateLimit5h) + return u +} + +// AddRateLimit5h adds v to the "rate_limit_5h" field. +func (u *APIKeyUpsert) AddRateLimit5h(v float64) *APIKeyUpsert { + u.Add(apikey.FieldRateLimit5h, v) + return u +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (u *APIKeyUpsert) SetRateLimit1d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldRateLimit1d, v) + return u +} + +// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateRateLimit1d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldRateLimit1d) + return u +} + +// AddRateLimit1d adds v to the "rate_limit_1d" field. +func (u *APIKeyUpsert) AddRateLimit1d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldRateLimit1d, v) + return u +} + +// SetRateLimit7d sets the "rate_limit_7d" field. 
+func (u *APIKeyUpsert) SetRateLimit7d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldRateLimit7d, v) + return u +} + +// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateRateLimit7d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldRateLimit7d) + return u +} + +// AddRateLimit7d adds v to the "rate_limit_7d" field. +func (u *APIKeyUpsert) AddRateLimit7d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldRateLimit7d, v) + return u +} + +// SetUsage5h sets the "usage_5h" field. +func (u *APIKeyUpsert) SetUsage5h(v float64) *APIKeyUpsert { + u.Set(apikey.FieldUsage5h, v) + return u +} + +// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateUsage5h() *APIKeyUpsert { + u.SetExcluded(apikey.FieldUsage5h) + return u +} + +// AddUsage5h adds v to the "usage_5h" field. +func (u *APIKeyUpsert) AddUsage5h(v float64) *APIKeyUpsert { + u.Add(apikey.FieldUsage5h, v) + return u +} + +// SetUsage1d sets the "usage_1d" field. +func (u *APIKeyUpsert) SetUsage1d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldUsage1d, v) + return u +} + +// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateUsage1d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldUsage1d) + return u +} + +// AddUsage1d adds v to the "usage_1d" field. +func (u *APIKeyUpsert) AddUsage1d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldUsage1d, v) + return u +} + +// SetUsage7d sets the "usage_7d" field. +func (u *APIKeyUpsert) SetUsage7d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldUsage7d, v) + return u +} + +// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateUsage7d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldUsage7d) + return u +} + +// AddUsage7d adds v to the "usage_7d" field. 
+func (u *APIKeyUpsert) AddUsage7d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldUsage7d, v) + return u +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (u *APIKeyUpsert) SetWindow5hStart(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldWindow5hStart, v) + return u +} + +// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateWindow5hStart() *APIKeyUpsert { + u.SetExcluded(apikey.FieldWindow5hStart) + return u +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (u *APIKeyUpsert) ClearWindow5hStart() *APIKeyUpsert { + u.SetNull(apikey.FieldWindow5hStart) + return u +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (u *APIKeyUpsert) SetWindow1dStart(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldWindow1dStart, v) + return u +} + +// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateWindow1dStart() *APIKeyUpsert { + u.SetExcluded(apikey.FieldWindow1dStart) + return u +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (u *APIKeyUpsert) ClearWindow1dStart() *APIKeyUpsert { + u.SetNull(apikey.FieldWindow1dStart) + return u +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (u *APIKeyUpsert) SetWindow7dStart(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldWindow7dStart, v) + return u +} + +// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateWindow7dStart() *APIKeyUpsert { + u.SetExcluded(apikey.FieldWindow7dStart) + return u +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (u *APIKeyUpsert) ClearWindow7dStart() *APIKeyUpsert { + u.SetNull(apikey.FieldWindow7dStart) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. 
// Using this option is equivalent to using: // @@ -980,6 +1346,195 @@ func (u *APIKeyUpsertOne) ClearExpiresAt() *APIKeyUpsertOne { }) } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (u *APIKeyUpsertOne) SetRateLimit5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit5h(v) + }) +} + +// AddRateLimit5h adds v to the "rate_limit_5h" field. +func (u *APIKeyUpsertOne) AddRateLimit5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit5h(v) + }) +} + +// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateRateLimit5h() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit5h() + }) +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (u *APIKeyUpsertOne) SetRateLimit1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit1d(v) + }) +} + +// AddRateLimit1d adds v to the "rate_limit_1d" field. +func (u *APIKeyUpsertOne) AddRateLimit1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit1d(v) + }) +} + +// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateRateLimit1d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit1d() + }) +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (u *APIKeyUpsertOne) SetRateLimit7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit7d(v) + }) +} + +// AddRateLimit7d adds v to the "rate_limit_7d" field. +func (u *APIKeyUpsertOne) AddRateLimit7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit7d(v) + }) +} + +// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create. 
+func (u *APIKeyUpsertOne) UpdateRateLimit7d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit7d() + }) +} + +// SetUsage5h sets the "usage_5h" field. +func (u *APIKeyUpsertOne) SetUsage5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage5h(v) + }) +} + +// AddUsage5h adds v to the "usage_5h" field. +func (u *APIKeyUpsertOne) AddUsage5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage5h(v) + }) +} + +// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateUsage5h() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage5h() + }) +} + +// SetUsage1d sets the "usage_1d" field. +func (u *APIKeyUpsertOne) SetUsage1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage1d(v) + }) +} + +// AddUsage1d adds v to the "usage_1d" field. +func (u *APIKeyUpsertOne) AddUsage1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage1d(v) + }) +} + +// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateUsage1d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage1d() + }) +} + +// SetUsage7d sets the "usage_7d" field. +func (u *APIKeyUpsertOne) SetUsage7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage7d(v) + }) +} + +// AddUsage7d adds v to the "usage_7d" field. +func (u *APIKeyUpsertOne) AddUsage7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage7d(v) + }) +} + +// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateUsage7d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage7d() + }) +} + +// SetWindow5hStart sets the "window_5h_start" field. 
+func (u *APIKeyUpsertOne) SetWindow5hStart(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow5hStart(v) + }) +} + +// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateWindow5hStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow5hStart() + }) +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (u *APIKeyUpsertOne) ClearWindow5hStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow5hStart() + }) +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (u *APIKeyUpsertOne) SetWindow1dStart(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow1dStart(v) + }) +} + +// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateWindow1dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow1dStart() + }) +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (u *APIKeyUpsertOne) ClearWindow1dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow1dStart() + }) +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (u *APIKeyUpsertOne) SetWindow7dStart(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow7dStart(v) + }) +} + +// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateWindow7dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow7dStart() + }) +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (u *APIKeyUpsertOne) ClearWindow7dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow7dStart() + }) +} + // Exec executes the query. 
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1429,6 +1984,195 @@ func (u *APIKeyUpsertBulk) ClearExpiresAt() *APIKeyUpsertBulk { }) } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (u *APIKeyUpsertBulk) SetRateLimit5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit5h(v) + }) +} + +// AddRateLimit5h adds v to the "rate_limit_5h" field. +func (u *APIKeyUpsertBulk) AddRateLimit5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit5h(v) + }) +} + +// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateRateLimit5h() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit5h() + }) +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (u *APIKeyUpsertBulk) SetRateLimit1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit1d(v) + }) +} + +// AddRateLimit1d adds v to the "rate_limit_1d" field. +func (u *APIKeyUpsertBulk) AddRateLimit1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit1d(v) + }) +} + +// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateRateLimit1d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit1d() + }) +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (u *APIKeyUpsertBulk) SetRateLimit7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit7d(v) + }) +} + +// AddRateLimit7d adds v to the "rate_limit_7d" field. +func (u *APIKeyUpsertBulk) AddRateLimit7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit7d(v) + }) +} + +// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create. 
+func (u *APIKeyUpsertBulk) UpdateRateLimit7d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit7d() + }) +} + +// SetUsage5h sets the "usage_5h" field. +func (u *APIKeyUpsertBulk) SetUsage5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage5h(v) + }) +} + +// AddUsage5h adds v to the "usage_5h" field. +func (u *APIKeyUpsertBulk) AddUsage5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage5h(v) + }) +} + +// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateUsage5h() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage5h() + }) +} + +// SetUsage1d sets the "usage_1d" field. +func (u *APIKeyUpsertBulk) SetUsage1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage1d(v) + }) +} + +// AddUsage1d adds v to the "usage_1d" field. +func (u *APIKeyUpsertBulk) AddUsage1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage1d(v) + }) +} + +// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateUsage1d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage1d() + }) +} + +// SetUsage7d sets the "usage_7d" field. +func (u *APIKeyUpsertBulk) SetUsage7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage7d(v) + }) +} + +// AddUsage7d adds v to the "usage_7d" field. +func (u *APIKeyUpsertBulk) AddUsage7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage7d(v) + }) +} + +// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateUsage7d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage7d() + }) +} + +// SetWindow5hStart sets the "window_5h_start" field. 
+func (u *APIKeyUpsertBulk) SetWindow5hStart(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow5hStart(v) + }) +} + +// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateWindow5hStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow5hStart() + }) +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (u *APIKeyUpsertBulk) ClearWindow5hStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow5hStart() + }) +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (u *APIKeyUpsertBulk) SetWindow1dStart(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow1dStart(v) + }) +} + +// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateWindow1dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow1dStart() + }) +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (u *APIKeyUpsertBulk) ClearWindow1dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow1dStart() + }) +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (u *APIKeyUpsertBulk) SetWindow7dStart(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow7dStart(v) + }) +} + +// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateWindow7dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow7dStart() + }) +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (u *APIKeyUpsertBulk) ClearWindow7dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow7dStart() + }) +} + // Exec executes the query. 
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index 6ca01854..db341e4c 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -252,6 +252,192 @@ func (_u *APIKeyUpdate) ClearExpiresAt() *APIKeyUpdate { return _u } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (_u *APIKeyUpdate) SetRateLimit5h(v float64) *APIKeyUpdate { + _u.mutation.ResetRateLimit5h() + _u.mutation.SetRateLimit5h(v) + return _u +} + +// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableRateLimit5h(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetRateLimit5h(*v) + } + return _u +} + +// AddRateLimit5h adds value to the "rate_limit_5h" field. +func (_u *APIKeyUpdate) AddRateLimit5h(v float64) *APIKeyUpdate { + _u.mutation.AddRateLimit5h(v) + return _u +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (_u *APIKeyUpdate) SetRateLimit1d(v float64) *APIKeyUpdate { + _u.mutation.ResetRateLimit1d() + _u.mutation.SetRateLimit1d(v) + return _u +} + +// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableRateLimit1d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetRateLimit1d(*v) + } + return _u +} + +// AddRateLimit1d adds value to the "rate_limit_1d" field. +func (_u *APIKeyUpdate) AddRateLimit1d(v float64) *APIKeyUpdate { + _u.mutation.AddRateLimit1d(v) + return _u +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (_u *APIKeyUpdate) SetRateLimit7d(v float64) *APIKeyUpdate { + _u.mutation.ResetRateLimit7d() + _u.mutation.SetRateLimit7d(v) + return _u +} + +// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil. 
+func (_u *APIKeyUpdate) SetNillableRateLimit7d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetRateLimit7d(*v) + } + return _u +} + +// AddRateLimit7d adds value to the "rate_limit_7d" field. +func (_u *APIKeyUpdate) AddRateLimit7d(v float64) *APIKeyUpdate { + _u.mutation.AddRateLimit7d(v) + return _u +} + +// SetUsage5h sets the "usage_5h" field. +func (_u *APIKeyUpdate) SetUsage5h(v float64) *APIKeyUpdate { + _u.mutation.ResetUsage5h() + _u.mutation.SetUsage5h(v) + return _u +} + +// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableUsage5h(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetUsage5h(*v) + } + return _u +} + +// AddUsage5h adds value to the "usage_5h" field. +func (_u *APIKeyUpdate) AddUsage5h(v float64) *APIKeyUpdate { + _u.mutation.AddUsage5h(v) + return _u +} + +// SetUsage1d sets the "usage_1d" field. +func (_u *APIKeyUpdate) SetUsage1d(v float64) *APIKeyUpdate { + _u.mutation.ResetUsage1d() + _u.mutation.SetUsage1d(v) + return _u +} + +// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableUsage1d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetUsage1d(*v) + } + return _u +} + +// AddUsage1d adds value to the "usage_1d" field. +func (_u *APIKeyUpdate) AddUsage1d(v float64) *APIKeyUpdate { + _u.mutation.AddUsage1d(v) + return _u +} + +// SetUsage7d sets the "usage_7d" field. +func (_u *APIKeyUpdate) SetUsage7d(v float64) *APIKeyUpdate { + _u.mutation.ResetUsage7d() + _u.mutation.SetUsage7d(v) + return _u +} + +// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableUsage7d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetUsage7d(*v) + } + return _u +} + +// AddUsage7d adds value to the "usage_7d" field. 
+func (_u *APIKeyUpdate) AddUsage7d(v float64) *APIKeyUpdate { + _u.mutation.AddUsage7d(v) + return _u +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (_u *APIKeyUpdate) SetWindow5hStart(v time.Time) *APIKeyUpdate { + _u.mutation.SetWindow5hStart(v) + return _u +} + +// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableWindow5hStart(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetWindow5hStart(*v) + } + return _u +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (_u *APIKeyUpdate) ClearWindow5hStart() *APIKeyUpdate { + _u.mutation.ClearWindow5hStart() + return _u +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (_u *APIKeyUpdate) SetWindow1dStart(v time.Time) *APIKeyUpdate { + _u.mutation.SetWindow1dStart(v) + return _u +} + +// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableWindow1dStart(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetWindow1dStart(*v) + } + return _u +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (_u *APIKeyUpdate) ClearWindow1dStart() *APIKeyUpdate { + _u.mutation.ClearWindow1dStart() + return _u +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (_u *APIKeyUpdate) SetWindow7dStart(v time.Time) *APIKeyUpdate { + _u.mutation.SetWindow7dStart(v) + return _u +} + +// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableWindow7dStart(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetWindow7dStart(*v) + } + return _u +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (_u *APIKeyUpdate) ClearWindow7dStart() *APIKeyUpdate { + _u.mutation.ClearWindow7dStart() + return _u +} + // SetUser sets the "user" edge to the User entity. 
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate { return _u.SetUserID(v.ID) @@ -456,6 +642,60 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ExpiresAtCleared() { _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) } + if value, ok := _u.mutation.RateLimit5h(); ok { + _spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit5h(); ok { + _spec.AddField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit1d(); ok { + _spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit1d(); ok { + _spec.AddField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit7d(); ok { + _spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit7d(); ok { + _spec.AddField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage5h(); ok { + _spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage5h(); ok { + _spec.AddField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage1d(); ok { + _spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage1d(); ok { + _spec.AddField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage7d(); ok { + _spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage7d(); ok { + _spec.AddField(apikey.FieldUsage7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Window5hStart(); ok { + _spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value) + } + if _u.mutation.Window5hStartCleared() { + _spec.ClearField(apikey.FieldWindow5hStart, field.TypeTime) + } + if value, ok := _u.mutation.Window1dStart(); 
ok { + _spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value) + } + if _u.mutation.Window1dStartCleared() { + _spec.ClearField(apikey.FieldWindow1dStart, field.TypeTime) + } + if value, ok := _u.mutation.Window7dStart(); ok { + _spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value) + } + if _u.mutation.Window7dStartCleared() { + _spec.ClearField(apikey.FieldWindow7dStart, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -799,6 +1039,192 @@ func (_u *APIKeyUpdateOne) ClearExpiresAt() *APIKeyUpdateOne { return _u } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (_u *APIKeyUpdateOne) SetRateLimit5h(v float64) *APIKeyUpdateOne { + _u.mutation.ResetRateLimit5h() + _u.mutation.SetRateLimit5h(v) + return _u +} + +// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableRateLimit5h(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetRateLimit5h(*v) + } + return _u +} + +// AddRateLimit5h adds value to the "rate_limit_5h" field. +func (_u *APIKeyUpdateOne) AddRateLimit5h(v float64) *APIKeyUpdateOne { + _u.mutation.AddRateLimit5h(v) + return _u +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (_u *APIKeyUpdateOne) SetRateLimit1d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetRateLimit1d() + _u.mutation.SetRateLimit1d(v) + return _u +} + +// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableRateLimit1d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetRateLimit1d(*v) + } + return _u +} + +// AddRateLimit1d adds value to the "rate_limit_1d" field. +func (_u *APIKeyUpdateOne) AddRateLimit1d(v float64) *APIKeyUpdateOne { + _u.mutation.AddRateLimit1d(v) + return _u +} + +// SetRateLimit7d sets the "rate_limit_7d" field. 
+func (_u *APIKeyUpdateOne) SetRateLimit7d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetRateLimit7d() + _u.mutation.SetRateLimit7d(v) + return _u +} + +// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableRateLimit7d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetRateLimit7d(*v) + } + return _u +} + +// AddRateLimit7d adds value to the "rate_limit_7d" field. +func (_u *APIKeyUpdateOne) AddRateLimit7d(v float64) *APIKeyUpdateOne { + _u.mutation.AddRateLimit7d(v) + return _u +} + +// SetUsage5h sets the "usage_5h" field. +func (_u *APIKeyUpdateOne) SetUsage5h(v float64) *APIKeyUpdateOne { + _u.mutation.ResetUsage5h() + _u.mutation.SetUsage5h(v) + return _u +} + +// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableUsage5h(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetUsage5h(*v) + } + return _u +} + +// AddUsage5h adds value to the "usage_5h" field. +func (_u *APIKeyUpdateOne) AddUsage5h(v float64) *APIKeyUpdateOne { + _u.mutation.AddUsage5h(v) + return _u +} + +// SetUsage1d sets the "usage_1d" field. +func (_u *APIKeyUpdateOne) SetUsage1d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetUsage1d() + _u.mutation.SetUsage1d(v) + return _u +} + +// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableUsage1d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetUsage1d(*v) + } + return _u +} + +// AddUsage1d adds value to the "usage_1d" field. +func (_u *APIKeyUpdateOne) AddUsage1d(v float64) *APIKeyUpdateOne { + _u.mutation.AddUsage1d(v) + return _u +} + +// SetUsage7d sets the "usage_7d" field. +func (_u *APIKeyUpdateOne) SetUsage7d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetUsage7d() + _u.mutation.SetUsage7d(v) + return _u +} + +// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil. 
+func (_u *APIKeyUpdateOne) SetNillableUsage7d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetUsage7d(*v) + } + return _u +} + +// AddUsage7d adds value to the "usage_7d" field. +func (_u *APIKeyUpdateOne) AddUsage7d(v float64) *APIKeyUpdateOne { + _u.mutation.AddUsage7d(v) + return _u +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (_u *APIKeyUpdateOne) SetWindow5hStart(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetWindow5hStart(v) + return _u +} + +// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableWindow5hStart(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetWindow5hStart(*v) + } + return _u +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (_u *APIKeyUpdateOne) ClearWindow5hStart() *APIKeyUpdateOne { + _u.mutation.ClearWindow5hStart() + return _u +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (_u *APIKeyUpdateOne) SetWindow1dStart(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetWindow1dStart(v) + return _u +} + +// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableWindow1dStart(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetWindow1dStart(*v) + } + return _u +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (_u *APIKeyUpdateOne) ClearWindow1dStart() *APIKeyUpdateOne { + _u.mutation.ClearWindow1dStart() + return _u +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (_u *APIKeyUpdateOne) SetWindow7dStart(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetWindow7dStart(v) + return _u +} + +// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil. 
+func (_u *APIKeyUpdateOne) SetNillableWindow7dStart(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetWindow7dStart(*v) + } + return _u +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (_u *APIKeyUpdateOne) ClearWindow7dStart() *APIKeyUpdateOne { + _u.mutation.ClearWindow7dStart() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne { return _u.SetUserID(v.ID) @@ -1033,6 +1459,60 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro if _u.mutation.ExpiresAtCleared() { _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) } + if value, ok := _u.mutation.RateLimit5h(); ok { + _spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit5h(); ok { + _spec.AddField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit1d(); ok { + _spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit1d(); ok { + _spec.AddField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit7d(); ok { + _spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit7d(); ok { + _spec.AddField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage5h(); ok { + _spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage5h(); ok { + _spec.AddField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage1d(); ok { + _spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage1d(); ok { + _spec.AddField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage7d(); ok { + _spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, 
value) + } + if value, ok := _u.mutation.AddedUsage7d(); ok { + _spec.AddField(apikey.FieldUsage7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Window5hStart(); ok { + _spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value) + } + if _u.mutation.Window5hStartCleared() { + _spec.ClearField(apikey.FieldWindow5hStart, field.TypeTime) + } + if value, ok := _u.mutation.Window1dStart(); ok { + _spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value) + } + if _u.mutation.Window1dStartCleared() { + _spec.ClearField(apikey.FieldWindow1dStart, field.TypeTime) + } + if value, ok := _u.mutation.Window7dStart(); ok { + _spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value) + } + if _u.mutation.Window7dStartCleared() { + _spec.ClearField(apikey.FieldWindow7dStart, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/ent/group.go b/backend/ent/group.go index 76c3cae2..3db54a64 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -78,6 +78,10 @@ type Group struct { SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` // 分组显示排序,数值越小越靠前 SortOrder int `json:"sort_order,omitempty"` + // 是否允许 /v1/messages 调度到此 OpenAI 分组 + AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"` + // 默认映射模型 ID,当账号级映射找不到时使用此值 + DefaultMappedModel string `json:"default_mapped_model,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. 
Edges GroupEdges `json:"edges"` @@ -186,13 +190,13 @@ func (*Group) scanValues(columns []string) ([]any, error) { switch columns[i] { case group.FieldModelRouting, group.FieldSupportedModelScopes: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject: + case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: values[i] = new(sql.NullInt64) - case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: + case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel: values[i] = new(sql.NullString) case group.FieldCreatedAt, group.FieldUpdatedAt, group.FieldDeletedAt: values[i] = new(sql.NullTime) @@ -415,6 +419,18 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.SortOrder = int(value.Int64) } + case group.FieldAllowMessagesDispatch: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field allow_messages_dispatch", values[i]) + } else if value.Valid { + _m.AllowMessagesDispatch = value.Bool + } + case group.FieldDefaultMappedModel: + if value, ok := values[i].(*sql.NullString); !ok { + return 
fmt.Errorf("unexpected type %T for field default_mapped_model", values[i]) + } else if value.Valid { + _m.DefaultMappedModel = value.String + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -608,6 +624,12 @@ func (_m *Group) String() string { builder.WriteString(", ") builder.WriteString("sort_order=") builder.WriteString(fmt.Sprintf("%v", _m.SortOrder)) + builder.WriteString(", ") + builder.WriteString("allow_messages_dispatch=") + builder.WriteString(fmt.Sprintf("%v", _m.AllowMessagesDispatch)) + builder.WriteString(", ") + builder.WriteString("default_mapped_model=") + builder.WriteString(_m.DefaultMappedModel) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 6ac4eea1..2612b6cf 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -75,6 +75,10 @@ const ( FieldSupportedModelScopes = "supported_model_scopes" // FieldSortOrder holds the string denoting the sort_order field in the database. FieldSortOrder = "sort_order" + // FieldAllowMessagesDispatch holds the string denoting the allow_messages_dispatch field in the database. + FieldAllowMessagesDispatch = "allow_messages_dispatch" + // FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database. + FieldDefaultMappedModel = "default_mapped_model" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -180,6 +184,8 @@ var Columns = []string{ FieldMcpXMLInject, FieldSupportedModelScopes, FieldSortOrder, + FieldAllowMessagesDispatch, + FieldDefaultMappedModel, } var ( @@ -247,6 +253,12 @@ var ( DefaultSupportedModelScopes []string // DefaultSortOrder holds the default value on creation for the "sort_order" field. 
DefaultSortOrder int + // DefaultAllowMessagesDispatch holds the default value on creation for the "allow_messages_dispatch" field. + DefaultAllowMessagesDispatch bool + // DefaultDefaultMappedModel holds the default value on creation for the "default_mapped_model" field. + DefaultDefaultMappedModel string + // DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. + DefaultMappedModelValidator func(string) error ) // OrderOption defines the ordering options for the Group queries. @@ -397,6 +409,16 @@ func BySortOrder(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldSortOrder, opts...).ToFunc() } +// ByAllowMessagesDispatch orders the results by the allow_messages_dispatch field. +func ByAllowMessagesDispatch(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAllowMessagesDispatch, opts...).ToFunc() +} + +// ByDefaultMappedModel orders the results by the default_mapped_model field. +func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 4cf65d0f..5dd8759e 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -195,6 +195,16 @@ func SortOrder(v int) predicate.Group { return predicate.Group(sql.FieldEQ(FieldSortOrder, v)) } +// AllowMessagesDispatch applies equality check predicate on the "allow_messages_dispatch" field. It's identical to AllowMessagesDispatchEQ. +func AllowMessagesDispatch(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v)) +} + +// DefaultMappedModel applies equality check predicate on the "default_mapped_model" field. It's identical to DefaultMappedModelEQ. 
+func DefaultMappedModel(v string) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1470,6 +1480,81 @@ func SortOrderLTE(v int) predicate.Group { return predicate.Group(sql.FieldLTE(FieldSortOrder, v)) } +// AllowMessagesDispatchEQ applies the EQ predicate on the "allow_messages_dispatch" field. +func AllowMessagesDispatchEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v)) +} + +// AllowMessagesDispatchNEQ applies the NEQ predicate on the "allow_messages_dispatch" field. +func AllowMessagesDispatchNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldAllowMessagesDispatch, v)) +} + +// DefaultMappedModelEQ applies the EQ predicate on the "default_mapped_model" field. +func DefaultMappedModelEQ(v string) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelNEQ applies the NEQ predicate on the "default_mapped_model" field. +func DefaultMappedModelNEQ(v string) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelIn applies the In predicate on the "default_mapped_model" field. +func DefaultMappedModelIn(vs ...string) predicate.Group { + return predicate.Group(sql.FieldIn(FieldDefaultMappedModel, vs...)) +} + +// DefaultMappedModelNotIn applies the NotIn predicate on the "default_mapped_model" field. +func DefaultMappedModelNotIn(vs ...string) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldDefaultMappedModel, vs...)) +} + +// DefaultMappedModelGT applies the GT predicate on the "default_mapped_model" field. 
+func DefaultMappedModelGT(v string) predicate.Group { + return predicate.Group(sql.FieldGT(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelGTE applies the GTE predicate on the "default_mapped_model" field. +func DefaultMappedModelGTE(v string) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelLT applies the LT predicate on the "default_mapped_model" field. +func DefaultMappedModelLT(v string) predicate.Group { + return predicate.Group(sql.FieldLT(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelLTE applies the LTE predicate on the "default_mapped_model" field. +func DefaultMappedModelLTE(v string) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelContains applies the Contains predicate on the "default_mapped_model" field. +func DefaultMappedModelContains(v string) predicate.Group { + return predicate.Group(sql.FieldContains(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelHasPrefix applies the HasPrefix predicate on the "default_mapped_model" field. +func DefaultMappedModelHasPrefix(v string) predicate.Group { + return predicate.Group(sql.FieldHasPrefix(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelHasSuffix applies the HasSuffix predicate on the "default_mapped_model" field. +func DefaultMappedModelHasSuffix(v string) predicate.Group { + return predicate.Group(sql.FieldHasSuffix(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelEqualFold applies the EqualFold predicate on the "default_mapped_model" field. +func DefaultMappedModelEqualFold(v string) predicate.Group { + return predicate.Group(sql.FieldEqualFold(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelContainsFold applies the ContainsFold predicate on the "default_mapped_model" field. 
+func DefaultMappedModelContainsFold(v string) predicate.Group { + return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 0ce5f959..6db5b974 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -424,6 +424,34 @@ func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate { return _c } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (_c *GroupCreate) SetAllowMessagesDispatch(v bool) *GroupCreate { + _c.mutation.SetAllowMessagesDispatch(v) + return _c +} + +// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil. +func (_c *GroupCreate) SetNillableAllowMessagesDispatch(v *bool) *GroupCreate { + if v != nil { + _c.SetAllowMessagesDispatch(*v) + } + return _c +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (_c *GroupCreate) SetDefaultMappedModel(v string) *GroupCreate { + _c.mutation.SetDefaultMappedModel(v) + return _c +} + +// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil. +func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate { + if v != nil { + _c.SetDefaultMappedModel(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) 
@@ -613,6 +641,14 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultSortOrder _c.mutation.SetSortOrder(v) } + if _, ok := _c.mutation.AllowMessagesDispatch(); !ok { + v := group.DefaultAllowMessagesDispatch + _c.mutation.SetAllowMessagesDispatch(v) + } + if _, ok := _c.mutation.DefaultMappedModel(); !ok { + v := group.DefaultDefaultMappedModel + _c.mutation.SetDefaultMappedModel(v) + } return nil } @@ -683,6 +719,17 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.SortOrder(); !ok { return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)} } + if _, ok := _c.mutation.AllowMessagesDispatch(); !ok { + return &ValidationError{Name: "allow_messages_dispatch", err: errors.New(`ent: missing required field "Group.allow_messages_dispatch"`)} + } + if _, ok := _c.mutation.DefaultMappedModel(); !ok { + return &ValidationError{Name: "default_mapped_model", err: errors.New(`ent: missing required field "Group.default_mapped_model"`)} + } + if v, ok := _c.mutation.DefaultMappedModel(); ok { + if err := group.DefaultMappedModelValidator(v); err != nil { + return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)} + } + } return nil } @@ -830,6 +877,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldSortOrder, field.TypeInt, value) _node.SortOrder = value } + if value, ok := _c.mutation.AllowMessagesDispatch(); ok { + _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) + _node.AllowMessagesDispatch = value + } + if value, ok := _c.mutation.DefaultMappedModel(); ok { + _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) + _node.DefaultMappedModel = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1520,6 +1575,30 @@ func (u *GroupUpsert) AddSortOrder(v int) 
*GroupUpsert { return u } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (u *GroupUpsert) SetAllowMessagesDispatch(v bool) *GroupUpsert { + u.Set(group.FieldAllowMessagesDispatch, v) + return u +} + +// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create. +func (u *GroupUpsert) UpdateAllowMessagesDispatch() *GroupUpsert { + u.SetExcluded(group.FieldAllowMessagesDispatch) + return u +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (u *GroupUpsert) SetDefaultMappedModel(v string) *GroupUpsert { + u.Set(group.FieldDefaultMappedModel, v) + return u +} + +// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create. +func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert { + u.SetExcluded(group.FieldDefaultMappedModel) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -2188,6 +2267,34 @@ func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne { }) } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (u *GroupUpsertOne) SetAllowMessagesDispatch(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetAllowMessagesDispatch(v) + }) +} + +// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateAllowMessagesDispatch() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateAllowMessagesDispatch() + }) +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (u *GroupUpsertOne) SetDefaultMappedModel(v string) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetDefaultMappedModel(v) + }) +} + +// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create. 
+func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateDefaultMappedModel() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -3022,6 +3129,34 @@ func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk { }) } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (u *GroupUpsertBulk) SetAllowMessagesDispatch(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetAllowMessagesDispatch(v) + }) +} + +// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateAllowMessagesDispatch() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateAllowMessagesDispatch() + }) +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (u *GroupUpsertBulk) SetDefaultMappedModel(v string) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetDefaultMappedModel(v) + }) +} + +// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateDefaultMappedModel() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 85575292..b3698596 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -625,6 +625,34 @@ func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate { return _u } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. 
+func (_u *GroupUpdate) SetAllowMessagesDispatch(v bool) *GroupUpdate { + _u.mutation.SetAllowMessagesDispatch(v) + return _u +} + +// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate { + if v != nil { + _u.SetAllowMessagesDispatch(*v) + } + return _u +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (_u *GroupUpdate) SetDefaultMappedModel(v string) *GroupUpdate { + _u.mutation.SetDefaultMappedModel(v) + return _u +} + +// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate { + if v != nil { + _u.SetDefaultMappedModel(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) 
@@ -910,6 +938,11 @@ func (_u *GroupUpdate) check() error { return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)} } } + if v, ok := _u.mutation.DefaultMappedModel(); ok { + if err := group.DefaultMappedModelValidator(v); err != nil { + return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)} + } + } return nil } @@ -1110,6 +1143,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AddedSortOrder(); ok { _spec.AddField(group.FieldSortOrder, field.TypeInt, value) } + if value, ok := _u.mutation.AllowMessagesDispatch(); ok { + _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) + } + if value, ok := _u.mutation.DefaultMappedModel(); ok { + _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -2014,6 +2053,34 @@ func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne { return _u } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (_u *GroupUpdateOne) SetAllowMessagesDispatch(v bool) *GroupUpdateOne { + _u.mutation.SetAllowMessagesDispatch(v) + return _u +} + +// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetAllowMessagesDispatch(*v) + } + return _u +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (_u *GroupUpdateOne) SetDefaultMappedModel(v string) *GroupUpdateOne { + _u.mutation.SetDefaultMappedModel(v) + return _u +} + +// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil. 
+func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateOne { + if v != nil { + _u.SetDefaultMappedModel(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -2312,6 +2379,11 @@ func (_u *GroupUpdateOne) check() error { return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)} } } + if v, ok := _u.mutation.DefaultMappedModel(); ok { + if err := group.DefaultMappedModelValidator(v); err != nil { + return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)} + } + } return nil } @@ -2529,6 +2601,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.AddedSortOrder(); ok { _spec.AddField(group.FieldSortOrder, field.TypeInt, value) } + if value, ok := _u.mutation.AllowMessagesDispatch(); ok { + _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) + } + if value, ok := _u.mutation.DefaultMappedModel(); ok { + _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 769dddce..ff1c1b88 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -24,6 +24,15 @@ var ( {Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "quota_used", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "expires_at", Type: field.TypeTime, Nullable: true}, + {Name: "rate_limit_5h", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": 
"decimal(20,8)"}}, + {Name: "rate_limit_1d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "rate_limit_7d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "usage_5h", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "usage_1d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "usage_7d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "window_5h_start", Type: field.TypeTime, Nullable: true}, + {Name: "window_1d_start", Type: field.TypeTime, Nullable: true}, + {Name: "window_7d_start", Type: field.TypeTime, Nullable: true}, {Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "user_id", Type: field.TypeInt64}, } @@ -35,13 +44,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "api_keys_groups_api_keys", - Columns: []*schema.Column{APIKeysColumns[13]}, + Columns: []*schema.Column{APIKeysColumns[22]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "api_keys_users_api_keys", - Columns: []*schema.Column{APIKeysColumns[14]}, + Columns: []*schema.Column{APIKeysColumns[23]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -50,12 +59,12 @@ var ( { Name: "apikey_user_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[14]}, + Columns: []*schema.Column{APIKeysColumns[23]}, }, { Name: "apikey_group_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[13]}, + Columns: []*schema.Column{APIKeysColumns[22]}, }, { Name: "apikey_status", @@ -97,6 +106,7 @@ var ( {Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "concurrency", Type: 
field.TypeInt, Default: 3}, + {Name: "load_factor", Type: field.TypeInt, Nullable: true}, {Name: "priority", Type: field.TypeInt, Default: 50}, {Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, @@ -123,7 +133,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "accounts_proxies_proxy", - Columns: []*schema.Column{AccountsColumns[27]}, + Columns: []*schema.Column{AccountsColumns[28]}, RefColumns: []*schema.Column{ProxiesColumns[0]}, OnDelete: schema.SetNull, }, @@ -142,52 +152,52 @@ var ( { Name: "account_status", Unique: false, - Columns: []*schema.Column{AccountsColumns[13]}, + Columns: []*schema.Column{AccountsColumns[14]}, }, { Name: "account_proxy_id", Unique: false, - Columns: []*schema.Column{AccountsColumns[27]}, + Columns: []*schema.Column{AccountsColumns[28]}, }, { Name: "account_priority", Unique: false, - Columns: []*schema.Column{AccountsColumns[11]}, + Columns: []*schema.Column{AccountsColumns[12]}, }, { Name: "account_last_used_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[15]}, + Columns: []*schema.Column{AccountsColumns[16]}, }, { Name: "account_schedulable", Unique: false, - Columns: []*schema.Column{AccountsColumns[18]}, + Columns: []*schema.Column{AccountsColumns[19]}, }, { Name: "account_rate_limited_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[19]}, + Columns: []*schema.Column{AccountsColumns[20]}, }, { Name: "account_rate_limit_reset_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[20]}, + Columns: []*schema.Column{AccountsColumns[21]}, }, { Name: "account_overload_until", Unique: false, - Columns: []*schema.Column{AccountsColumns[21]}, + Columns: []*schema.Column{AccountsColumns[22]}, }, { Name: "account_platform_priority", Unique: false, - Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]}, + Columns: 
[]*schema.Column{AccountsColumns[6], AccountsColumns[12]}, }, { Name: "account_priority_status", Unique: false, - Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]}, + Columns: []*schema.Column{AccountsColumns[12], AccountsColumns[14]}, }, { Name: "account_deleted_at", @@ -241,6 +251,7 @@ var ( {Name: "title", Type: field.TypeString, Size: 200}, {Name: "content", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, {Name: "status", Type: field.TypeString, Size: 20, Default: "draft"}, + {Name: "notify_mode", Type: field.TypeString, Size: 20, Default: "silent"}, {Name: "targeting", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "starts_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "ends_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, @@ -263,17 +274,17 @@ var ( { Name: "announcement_created_at", Unique: false, - Columns: []*schema.Column{AnnouncementsColumns[9]}, + Columns: []*schema.Column{AnnouncementsColumns[10]}, }, { Name: "announcement_starts_at", Unique: false, - Columns: []*schema.Column{AnnouncementsColumns[5]}, + Columns: []*schema.Column{AnnouncementsColumns[6]}, }, { Name: "announcement_ends_at", Unique: false, - Columns: []*schema.Column{AnnouncementsColumns[6]}, + Columns: []*schema.Column{AnnouncementsColumns[7]}, }, }, } @@ -397,6 +408,8 @@ var ( {Name: "mcp_xml_inject", Type: field.TypeBool, Default: true}, {Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "sort_order", Type: field.TypeInt, Default: 0}, + {Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false}, + {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, } // GroupsTable holds the schema information for the "groups" table. 
GroupsTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 823cd389..652adcac 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -91,6 +91,21 @@ type APIKeyMutation struct { quota_used *float64 addquota_used *float64 expires_at *time.Time + rate_limit_5h *float64 + addrate_limit_5h *float64 + rate_limit_1d *float64 + addrate_limit_1d *float64 + rate_limit_7d *float64 + addrate_limit_7d *float64 + usage_5h *float64 + addusage_5h *float64 + usage_1d *float64 + addusage_1d *float64 + usage_7d *float64 + addusage_7d *float64 + window_5h_start *time.Time + window_1d_start *time.Time + window_7d_start *time.Time clearedFields map[string]struct{} user *int64 cleareduser bool @@ -856,6 +871,489 @@ func (m *APIKeyMutation) ResetExpiresAt() { delete(m.clearedFields, apikey.FieldExpiresAt) } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (m *APIKeyMutation) SetRateLimit5h(f float64) { + m.rate_limit_5h = &f + m.addrate_limit_5h = nil +} + +// RateLimit5h returns the value of the "rate_limit_5h" field in the mutation. +func (m *APIKeyMutation) RateLimit5h() (r float64, exists bool) { + v := m.rate_limit_5h + if v == nil { + return + } + return *v, true +} + +// OldRateLimit5h returns the old "rate_limit_5h" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldRateLimit5h(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateLimit5h is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateLimit5h requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateLimit5h: %w", err) + } + return oldValue.RateLimit5h, nil +} + +// AddRateLimit5h adds f to the "rate_limit_5h" field. +func (m *APIKeyMutation) AddRateLimit5h(f float64) { + if m.addrate_limit_5h != nil { + *m.addrate_limit_5h += f + } else { + m.addrate_limit_5h = &f + } +} + +// AddedRateLimit5h returns the value that was added to the "rate_limit_5h" field in this mutation. +func (m *APIKeyMutation) AddedRateLimit5h() (r float64, exists bool) { + v := m.addrate_limit_5h + if v == nil { + return + } + return *v, true +} + +// ResetRateLimit5h resets all changes to the "rate_limit_5h" field. +func (m *APIKeyMutation) ResetRateLimit5h() { + m.rate_limit_5h = nil + m.addrate_limit_5h = nil +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (m *APIKeyMutation) SetRateLimit1d(f float64) { + m.rate_limit_1d = &f + m.addrate_limit_1d = nil +} + +// RateLimit1d returns the value of the "rate_limit_1d" field in the mutation. +func (m *APIKeyMutation) RateLimit1d() (r float64, exists bool) { + v := m.rate_limit_1d + if v == nil { + return + } + return *v, true +} + +// OldRateLimit1d returns the old "rate_limit_1d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldRateLimit1d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateLimit1d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateLimit1d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateLimit1d: %w", err) + } + return oldValue.RateLimit1d, nil +} + +// AddRateLimit1d adds f to the "rate_limit_1d" field. +func (m *APIKeyMutation) AddRateLimit1d(f float64) { + if m.addrate_limit_1d != nil { + *m.addrate_limit_1d += f + } else { + m.addrate_limit_1d = &f + } +} + +// AddedRateLimit1d returns the value that was added to the "rate_limit_1d" field in this mutation. +func (m *APIKeyMutation) AddedRateLimit1d() (r float64, exists bool) { + v := m.addrate_limit_1d + if v == nil { + return + } + return *v, true +} + +// ResetRateLimit1d resets all changes to the "rate_limit_1d" field. +func (m *APIKeyMutation) ResetRateLimit1d() { + m.rate_limit_1d = nil + m.addrate_limit_1d = nil +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (m *APIKeyMutation) SetRateLimit7d(f float64) { + m.rate_limit_7d = &f + m.addrate_limit_7d = nil +} + +// RateLimit7d returns the value of the "rate_limit_7d" field in the mutation. +func (m *APIKeyMutation) RateLimit7d() (r float64, exists bool) { + v := m.rate_limit_7d + if v == nil { + return + } + return *v, true +} + +// OldRateLimit7d returns the old "rate_limit_7d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldRateLimit7d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateLimit7d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateLimit7d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateLimit7d: %w", err) + } + return oldValue.RateLimit7d, nil +} + +// AddRateLimit7d adds f to the "rate_limit_7d" field. +func (m *APIKeyMutation) AddRateLimit7d(f float64) { + if m.addrate_limit_7d != nil { + *m.addrate_limit_7d += f + } else { + m.addrate_limit_7d = &f + } +} + +// AddedRateLimit7d returns the value that was added to the "rate_limit_7d" field in this mutation. +func (m *APIKeyMutation) AddedRateLimit7d() (r float64, exists bool) { + v := m.addrate_limit_7d + if v == nil { + return + } + return *v, true +} + +// ResetRateLimit7d resets all changes to the "rate_limit_7d" field. +func (m *APIKeyMutation) ResetRateLimit7d() { + m.rate_limit_7d = nil + m.addrate_limit_7d = nil +} + +// SetUsage5h sets the "usage_5h" field. +func (m *APIKeyMutation) SetUsage5h(f float64) { + m.usage_5h = &f + m.addusage_5h = nil +} + +// Usage5h returns the value of the "usage_5h" field in the mutation. +func (m *APIKeyMutation) Usage5h() (r float64, exists bool) { + v := m.usage_5h + if v == nil { + return + } + return *v, true +} + +// OldUsage5h returns the old "usage_5h" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldUsage5h(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsage5h is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsage5h requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsage5h: %w", err) + } + return oldValue.Usage5h, nil +} + +// AddUsage5h adds f to the "usage_5h" field. +func (m *APIKeyMutation) AddUsage5h(f float64) { + if m.addusage_5h != nil { + *m.addusage_5h += f + } else { + m.addusage_5h = &f + } +} + +// AddedUsage5h returns the value that was added to the "usage_5h" field in this mutation. +func (m *APIKeyMutation) AddedUsage5h() (r float64, exists bool) { + v := m.addusage_5h + if v == nil { + return + } + return *v, true +} + +// ResetUsage5h resets all changes to the "usage_5h" field. +func (m *APIKeyMutation) ResetUsage5h() { + m.usage_5h = nil + m.addusage_5h = nil +} + +// SetUsage1d sets the "usage_1d" field. +func (m *APIKeyMutation) SetUsage1d(f float64) { + m.usage_1d = &f + m.addusage_1d = nil +} + +// Usage1d returns the value of the "usage_1d" field in the mutation. +func (m *APIKeyMutation) Usage1d() (r float64, exists bool) { + v := m.usage_1d + if v == nil { + return + } + return *v, true +} + +// OldUsage1d returns the old "usage_1d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldUsage1d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsage1d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsage1d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsage1d: %w", err) + } + return oldValue.Usage1d, nil +} + +// AddUsage1d adds f to the "usage_1d" field. +func (m *APIKeyMutation) AddUsage1d(f float64) { + if m.addusage_1d != nil { + *m.addusage_1d += f + } else { + m.addusage_1d = &f + } +} + +// AddedUsage1d returns the value that was added to the "usage_1d" field in this mutation. +func (m *APIKeyMutation) AddedUsage1d() (r float64, exists bool) { + v := m.addusage_1d + if v == nil { + return + } + return *v, true +} + +// ResetUsage1d resets all changes to the "usage_1d" field. +func (m *APIKeyMutation) ResetUsage1d() { + m.usage_1d = nil + m.addusage_1d = nil +} + +// SetUsage7d sets the "usage_7d" field. +func (m *APIKeyMutation) SetUsage7d(f float64) { + m.usage_7d = &f + m.addusage_7d = nil +} + +// Usage7d returns the value of the "usage_7d" field in the mutation. +func (m *APIKeyMutation) Usage7d() (r float64, exists bool) { + v := m.usage_7d + if v == nil { + return + } + return *v, true +} + +// OldUsage7d returns the old "usage_7d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldUsage7d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsage7d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsage7d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsage7d: %w", err) + } + return oldValue.Usage7d, nil +} + +// AddUsage7d adds f to the "usage_7d" field. +func (m *APIKeyMutation) AddUsage7d(f float64) { + if m.addusage_7d != nil { + *m.addusage_7d += f + } else { + m.addusage_7d = &f + } +} + +// AddedUsage7d returns the value that was added to the "usage_7d" field in this mutation. +func (m *APIKeyMutation) AddedUsage7d() (r float64, exists bool) { + v := m.addusage_7d + if v == nil { + return + } + return *v, true +} + +// ResetUsage7d resets all changes to the "usage_7d" field. +func (m *APIKeyMutation) ResetUsage7d() { + m.usage_7d = nil + m.addusage_7d = nil +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (m *APIKeyMutation) SetWindow5hStart(t time.Time) { + m.window_5h_start = &t +} + +// Window5hStart returns the value of the "window_5h_start" field in the mutation. +func (m *APIKeyMutation) Window5hStart() (r time.Time, exists bool) { + v := m.window_5h_start + if v == nil { + return + } + return *v, true +} + +// OldWindow5hStart returns the old "window_5h_start" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldWindow5hStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWindow5hStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWindow5hStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWindow5hStart: %w", err) + } + return oldValue.Window5hStart, nil +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (m *APIKeyMutation) ClearWindow5hStart() { + m.window_5h_start = nil + m.clearedFields[apikey.FieldWindow5hStart] = struct{}{} +} + +// Window5hStartCleared returns if the "window_5h_start" field was cleared in this mutation. +func (m *APIKeyMutation) Window5hStartCleared() bool { + _, ok := m.clearedFields[apikey.FieldWindow5hStart] + return ok +} + +// ResetWindow5hStart resets all changes to the "window_5h_start" field. +func (m *APIKeyMutation) ResetWindow5hStart() { + m.window_5h_start = nil + delete(m.clearedFields, apikey.FieldWindow5hStart) +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (m *APIKeyMutation) SetWindow1dStart(t time.Time) { + m.window_1d_start = &t +} + +// Window1dStart returns the value of the "window_1d_start" field in the mutation. +func (m *APIKeyMutation) Window1dStart() (r time.Time, exists bool) { + v := m.window_1d_start + if v == nil { + return + } + return *v, true +} + +// OldWindow1dStart returns the old "window_1d_start" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldWindow1dStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWindow1dStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWindow1dStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWindow1dStart: %w", err) + } + return oldValue.Window1dStart, nil +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (m *APIKeyMutation) ClearWindow1dStart() { + m.window_1d_start = nil + m.clearedFields[apikey.FieldWindow1dStart] = struct{}{} +} + +// Window1dStartCleared returns if the "window_1d_start" field was cleared in this mutation. +func (m *APIKeyMutation) Window1dStartCleared() bool { + _, ok := m.clearedFields[apikey.FieldWindow1dStart] + return ok +} + +// ResetWindow1dStart resets all changes to the "window_1d_start" field. +func (m *APIKeyMutation) ResetWindow1dStart() { + m.window_1d_start = nil + delete(m.clearedFields, apikey.FieldWindow1dStart) +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (m *APIKeyMutation) SetWindow7dStart(t time.Time) { + m.window_7d_start = &t +} + +// Window7dStart returns the value of the "window_7d_start" field in the mutation. +func (m *APIKeyMutation) Window7dStart() (r time.Time, exists bool) { + v := m.window_7d_start + if v == nil { + return + } + return *v, true +} + +// OldWindow7dStart returns the old "window_7d_start" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldWindow7dStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWindow7dStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWindow7dStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWindow7dStart: %w", err) + } + return oldValue.Window7dStart, nil +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (m *APIKeyMutation) ClearWindow7dStart() { + m.window_7d_start = nil + m.clearedFields[apikey.FieldWindow7dStart] = struct{}{} +} + +// Window7dStartCleared returns if the "window_7d_start" field was cleared in this mutation. +func (m *APIKeyMutation) Window7dStartCleared() bool { + _, ok := m.clearedFields[apikey.FieldWindow7dStart] + return ok +} + +// ResetWindow7dStart resets all changes to the "window_7d_start" field. +func (m *APIKeyMutation) ResetWindow7dStart() { + m.window_7d_start = nil + delete(m.clearedFields, apikey.FieldWindow7dStart) +} + // ClearUser clears the "user" edge to the User entity. func (m *APIKeyMutation) ClearUser() { m.cleareduser = true @@ -998,7 +1496,7 @@ func (m *APIKeyMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *APIKeyMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 23) if m.created_at != nil { fields = append(fields, apikey.FieldCreatedAt) } @@ -1041,6 +1539,33 @@ func (m *APIKeyMutation) Fields() []string { if m.expires_at != nil { fields = append(fields, apikey.FieldExpiresAt) } + if m.rate_limit_5h != nil { + fields = append(fields, apikey.FieldRateLimit5h) + } + if m.rate_limit_1d != nil { + fields = append(fields, apikey.FieldRateLimit1d) + } + if m.rate_limit_7d != nil { + fields = append(fields, apikey.FieldRateLimit7d) + } + if m.usage_5h != nil { + fields = append(fields, apikey.FieldUsage5h) + } + if m.usage_1d != nil { + fields = append(fields, apikey.FieldUsage1d) + } + if m.usage_7d != nil { + fields = append(fields, apikey.FieldUsage7d) + } + if m.window_5h_start != nil { + fields = append(fields, apikey.FieldWindow5hStart) + } + if m.window_1d_start != nil { + fields = append(fields, apikey.FieldWindow1dStart) + } + if m.window_7d_start != nil { + fields = append(fields, apikey.FieldWindow7dStart) + } return fields } @@ -1077,6 +1602,24 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { return m.QuotaUsed() case apikey.FieldExpiresAt: return m.ExpiresAt() + case apikey.FieldRateLimit5h: + return m.RateLimit5h() + case apikey.FieldRateLimit1d: + return m.RateLimit1d() + case apikey.FieldRateLimit7d: + return m.RateLimit7d() + case apikey.FieldUsage5h: + return m.Usage5h() + case apikey.FieldUsage1d: + return m.Usage1d() + case apikey.FieldUsage7d: + return m.Usage7d() + case apikey.FieldWindow5hStart: + return m.Window5hStart() + case apikey.FieldWindow1dStart: + return m.Window1dStart() + case apikey.FieldWindow7dStart: + return m.Window7dStart() } return nil, false } @@ -1114,6 +1657,24 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldQuotaUsed(ctx) case apikey.FieldExpiresAt: return m.OldExpiresAt(ctx) + case apikey.FieldRateLimit5h: + 
return m.OldRateLimit5h(ctx) + case apikey.FieldRateLimit1d: + return m.OldRateLimit1d(ctx) + case apikey.FieldRateLimit7d: + return m.OldRateLimit7d(ctx) + case apikey.FieldUsage5h: + return m.OldUsage5h(ctx) + case apikey.FieldUsage1d: + return m.OldUsage1d(ctx) + case apikey.FieldUsage7d: + return m.OldUsage7d(ctx) + case apikey.FieldWindow5hStart: + return m.OldWindow5hStart(ctx) + case apikey.FieldWindow1dStart: + return m.OldWindow1dStart(ctx) + case apikey.FieldWindow7dStart: + return m.OldWindow7dStart(ctx) } return nil, fmt.Errorf("unknown APIKey field %s", name) } @@ -1221,6 +1782,69 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { } m.SetExpiresAt(v) return nil + case apikey.FieldRateLimit5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateLimit5h(v) + return nil + case apikey.FieldRateLimit1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateLimit1d(v) + return nil + case apikey.FieldRateLimit7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateLimit7d(v) + return nil + case apikey.FieldUsage5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsage5h(v) + return nil + case apikey.FieldUsage1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsage1d(v) + return nil + case apikey.FieldUsage7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsage7d(v) + return nil + case apikey.FieldWindow5hStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWindow5hStart(v) + return nil + case apikey.FieldWindow1dStart: + v, ok := value.(time.Time) + if !ok 
{ + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWindow1dStart(v) + return nil + case apikey.FieldWindow7dStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWindow7dStart(v) + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -1235,6 +1859,24 @@ func (m *APIKeyMutation) AddedFields() []string { if m.addquota_used != nil { fields = append(fields, apikey.FieldQuotaUsed) } + if m.addrate_limit_5h != nil { + fields = append(fields, apikey.FieldRateLimit5h) + } + if m.addrate_limit_1d != nil { + fields = append(fields, apikey.FieldRateLimit1d) + } + if m.addrate_limit_7d != nil { + fields = append(fields, apikey.FieldRateLimit7d) + } + if m.addusage_5h != nil { + fields = append(fields, apikey.FieldUsage5h) + } + if m.addusage_1d != nil { + fields = append(fields, apikey.FieldUsage1d) + } + if m.addusage_7d != nil { + fields = append(fields, apikey.FieldUsage7d) + } return fields } @@ -1247,6 +1889,18 @@ func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) { return m.AddedQuota() case apikey.FieldQuotaUsed: return m.AddedQuotaUsed() + case apikey.FieldRateLimit5h: + return m.AddedRateLimit5h() + case apikey.FieldRateLimit1d: + return m.AddedRateLimit1d() + case apikey.FieldRateLimit7d: + return m.AddedRateLimit7d() + case apikey.FieldUsage5h: + return m.AddedUsage5h() + case apikey.FieldUsage1d: + return m.AddedUsage1d() + case apikey.FieldUsage7d: + return m.AddedUsage7d() } return nil, false } @@ -1270,6 +1924,48 @@ func (m *APIKeyMutation) AddField(name string, value ent.Value) error { } m.AddQuotaUsed(v) return nil + case apikey.FieldRateLimit5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRateLimit5h(v) + return nil + case apikey.FieldRateLimit1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } 
+ m.AddRateLimit1d(v) + return nil + case apikey.FieldRateLimit7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRateLimit7d(v) + return nil + case apikey.FieldUsage5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddUsage5h(v) + return nil + case apikey.FieldUsage1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddUsage1d(v) + return nil + case apikey.FieldUsage7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddUsage7d(v) + return nil } return fmt.Errorf("unknown APIKey numeric field %s", name) } @@ -1296,6 +1992,15 @@ func (m *APIKeyMutation) ClearedFields() []string { if m.FieldCleared(apikey.FieldExpiresAt) { fields = append(fields, apikey.FieldExpiresAt) } + if m.FieldCleared(apikey.FieldWindow5hStart) { + fields = append(fields, apikey.FieldWindow5hStart) + } + if m.FieldCleared(apikey.FieldWindow1dStart) { + fields = append(fields, apikey.FieldWindow1dStart) + } + if m.FieldCleared(apikey.FieldWindow7dStart) { + fields = append(fields, apikey.FieldWindow7dStart) + } return fields } @@ -1328,6 +2033,15 @@ func (m *APIKeyMutation) ClearField(name string) error { case apikey.FieldExpiresAt: m.ClearExpiresAt() return nil + case apikey.FieldWindow5hStart: + m.ClearWindow5hStart() + return nil + case apikey.FieldWindow1dStart: + m.ClearWindow1dStart() + return nil + case apikey.FieldWindow7dStart: + m.ClearWindow7dStart() + return nil } return fmt.Errorf("unknown APIKey nullable field %s", name) } @@ -1378,6 +2092,33 @@ func (m *APIKeyMutation) ResetField(name string) error { case apikey.FieldExpiresAt: m.ResetExpiresAt() return nil + case apikey.FieldRateLimit5h: + m.ResetRateLimit5h() + return nil + case apikey.FieldRateLimit1d: + m.ResetRateLimit1d() + return nil + case 
apikey.FieldRateLimit7d: + m.ResetRateLimit7d() + return nil + case apikey.FieldUsage5h: + m.ResetUsage5h() + return nil + case apikey.FieldUsage1d: + m.ResetUsage1d() + return nil + case apikey.FieldUsage7d: + m.ResetUsage7d() + return nil + case apikey.FieldWindow5hStart: + m.ResetWindow5hStart() + return nil + case apikey.FieldWindow1dStart: + m.ResetWindow1dStart() + return nil + case apikey.FieldWindow7dStart: + m.ResetWindow7dStart() + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -1519,6 +2260,8 @@ type AccountMutation struct { extra *map[string]interface{} concurrency *int addconcurrency *int + load_factor *int + addload_factor *int priority *int addpriority *int rate_multiplier *float64 @@ -2104,6 +2847,76 @@ func (m *AccountMutation) ResetConcurrency() { m.addconcurrency = nil } +// SetLoadFactor sets the "load_factor" field. +func (m *AccountMutation) SetLoadFactor(i int) { + m.load_factor = &i + m.addload_factor = nil +} + +// LoadFactor returns the value of the "load_factor" field in the mutation. +func (m *AccountMutation) LoadFactor() (r int, exists bool) { + v := m.load_factor + if v == nil { + return + } + return *v, true +} + +// OldLoadFactor returns the old "load_factor" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *AccountMutation) OldLoadFactor(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLoadFactor is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLoadFactor requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLoadFactor: %w", err) + } + return oldValue.LoadFactor, nil +} + +// AddLoadFactor adds i to the "load_factor" field. +func (m *AccountMutation) AddLoadFactor(i int) { + if m.addload_factor != nil { + *m.addload_factor += i + } else { + m.addload_factor = &i + } +} + +// AddedLoadFactor returns the value that was added to the "load_factor" field in this mutation. +func (m *AccountMutation) AddedLoadFactor() (r int, exists bool) { + v := m.addload_factor + if v == nil { + return + } + return *v, true +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (m *AccountMutation) ClearLoadFactor() { + m.load_factor = nil + m.addload_factor = nil + m.clearedFields[account.FieldLoadFactor] = struct{}{} +} + +// LoadFactorCleared returns if the "load_factor" field was cleared in this mutation. +func (m *AccountMutation) LoadFactorCleared() bool { + _, ok := m.clearedFields[account.FieldLoadFactor] + return ok +} + +// ResetLoadFactor resets all changes to the "load_factor" field. +func (m *AccountMutation) ResetLoadFactor() { + m.load_factor = nil + m.addload_factor = nil + delete(m.clearedFields, account.FieldLoadFactor) +} + // SetPriority sets the "priority" field. func (m *AccountMutation) SetPriority(i int) { m.priority = &i @@ -3032,7 +3845,7 @@ func (m *AccountMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *AccountMutation) Fields() []string { - fields := make([]string, 0, 27) + fields := make([]string, 0, 28) if m.created_at != nil { fields = append(fields, account.FieldCreatedAt) } @@ -3066,6 +3879,9 @@ func (m *AccountMutation) Fields() []string { if m.concurrency != nil { fields = append(fields, account.FieldConcurrency) } + if m.load_factor != nil { + fields = append(fields, account.FieldLoadFactor) + } if m.priority != nil { fields = append(fields, account.FieldPriority) } @@ -3144,6 +3960,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) { return m.ProxyID() case account.FieldConcurrency: return m.Concurrency() + case account.FieldLoadFactor: + return m.LoadFactor() case account.FieldPriority: return m.Priority() case account.FieldRateMultiplier: @@ -3207,6 +4025,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldProxyID(ctx) case account.FieldConcurrency: return m.OldConcurrency(ctx) + case account.FieldLoadFactor: + return m.OldLoadFactor(ctx) case account.FieldPriority: return m.OldPriority(ctx) case account.FieldRateMultiplier: @@ -3325,6 +4145,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error { } m.SetConcurrency(v) return nil + case account.FieldLoadFactor: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLoadFactor(v) + return nil case account.FieldPriority: v, ok := value.(int) if !ok { @@ -3448,6 +4275,9 @@ func (m *AccountMutation) AddedFields() []string { if m.addconcurrency != nil { fields = append(fields, account.FieldConcurrency) } + if m.addload_factor != nil { + fields = append(fields, account.FieldLoadFactor) + } if m.addpriority != nil { fields = append(fields, account.FieldPriority) } @@ -3464,6 +4294,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) { switch name { case account.FieldConcurrency: return m.AddedConcurrency() + case account.FieldLoadFactor: + 
return m.AddedLoadFactor() case account.FieldPriority: return m.AddedPriority() case account.FieldRateMultiplier: @@ -3484,6 +4316,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error { } m.AddConcurrency(v) return nil + case account.FieldLoadFactor: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddLoadFactor(v) + return nil case account.FieldPriority: v, ok := value.(int) if !ok { @@ -3515,6 +4354,9 @@ func (m *AccountMutation) ClearedFields() []string { if m.FieldCleared(account.FieldProxyID) { fields = append(fields, account.FieldProxyID) } + if m.FieldCleared(account.FieldLoadFactor) { + fields = append(fields, account.FieldLoadFactor) + } if m.FieldCleared(account.FieldErrorMessage) { fields = append(fields, account.FieldErrorMessage) } @@ -3571,6 +4413,9 @@ func (m *AccountMutation) ClearField(name string) error { case account.FieldProxyID: m.ClearProxyID() return nil + case account.FieldLoadFactor: + m.ClearLoadFactor() + return nil case account.FieldErrorMessage: m.ClearErrorMessage() return nil @@ -3645,6 +4490,9 @@ func (m *AccountMutation) ResetField(name string) error { case account.FieldConcurrency: m.ResetConcurrency() return nil + case account.FieldLoadFactor: + m.ResetLoadFactor() + return nil case account.FieldPriority: m.ResetPriority() return nil @@ -4319,6 +5167,7 @@ type AnnouncementMutation struct { title *string content *string status *string + notify_mode *string targeting *domain.AnnouncementTargeting starts_at *time.Time ends_at *time.Time @@ -4543,6 +5392,42 @@ func (m *AnnouncementMutation) ResetStatus() { m.status = nil } +// SetNotifyMode sets the "notify_mode" field. +func (m *AnnouncementMutation) SetNotifyMode(s string) { + m.notify_mode = &s +} + +// NotifyMode returns the value of the "notify_mode" field in the mutation. 
+func (m *AnnouncementMutation) NotifyMode() (r string, exists bool) { + v := m.notify_mode + if v == nil { + return + } + return *v, true +} + +// OldNotifyMode returns the old "notify_mode" field's value of the Announcement entity. +// If the Announcement object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AnnouncementMutation) OldNotifyMode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldNotifyMode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldNotifyMode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldNotifyMode: %w", err) + } + return oldValue.NotifyMode, nil +} + +// ResetNotifyMode resets all changes to the "notify_mode" field. +func (m *AnnouncementMutation) ResetNotifyMode() { + m.notify_mode = nil +} + // SetTargeting sets the "targeting" field. func (m *AnnouncementMutation) SetTargeting(dt domain.AnnouncementTargeting) { m.targeting = &dt @@ -4990,7 +5875,7 @@ func (m *AnnouncementMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *AnnouncementMutation) Fields() []string { - fields := make([]string, 0, 10) + fields := make([]string, 0, 11) if m.title != nil { fields = append(fields, announcement.FieldTitle) } @@ -5000,6 +5885,9 @@ func (m *AnnouncementMutation) Fields() []string { if m.status != nil { fields = append(fields, announcement.FieldStatus) } + if m.notify_mode != nil { + fields = append(fields, announcement.FieldNotifyMode) + } if m.targeting != nil { fields = append(fields, announcement.FieldTargeting) } @@ -5035,6 +5923,8 @@ func (m *AnnouncementMutation) Field(name string) (ent.Value, bool) { return m.Content() case announcement.FieldStatus: return m.Status() + case announcement.FieldNotifyMode: + return m.NotifyMode() case announcement.FieldTargeting: return m.Targeting() case announcement.FieldStartsAt: @@ -5064,6 +5954,8 @@ func (m *AnnouncementMutation) OldField(ctx context.Context, name string) (ent.V return m.OldContent(ctx) case announcement.FieldStatus: return m.OldStatus(ctx) + case announcement.FieldNotifyMode: + return m.OldNotifyMode(ctx) case announcement.FieldTargeting: return m.OldTargeting(ctx) case announcement.FieldStartsAt: @@ -5108,6 +6000,13 @@ func (m *AnnouncementMutation) SetField(name string, value ent.Value) error { } m.SetStatus(v) return nil + case announcement.FieldNotifyMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetNotifyMode(v) + return nil case announcement.FieldTargeting: v, ok := value.(domain.AnnouncementTargeting) if !ok { @@ -5275,6 +6174,9 @@ func (m *AnnouncementMutation) ResetField(name string) error { case announcement.FieldStatus: m.ResetStatus() return nil + case announcement.FieldNotifyMode: + m.ResetNotifyMode() + return nil case announcement.FieldTargeting: m.ResetTargeting() return nil @@ -7348,6 +8250,8 @@ type GroupMutation struct { appendsupported_model_scopes []string sort_order *int addsort_order *int + allow_messages_dispatch *bool + 
default_mapped_model *string clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -9092,6 +9996,78 @@ func (m *GroupMutation) ResetSortOrder() { m.addsort_order = nil } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (m *GroupMutation) SetAllowMessagesDispatch(b bool) { + m.allow_messages_dispatch = &b +} + +// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation. +func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) { + v := m.allow_messages_dispatch + if v == nil { + return + } + return *v, true +} + +// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err) + } + return oldValue.AllowMessagesDispatch, nil +} + +// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field. +func (m *GroupMutation) ResetAllowMessagesDispatch() { + m.allow_messages_dispatch = nil +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (m *GroupMutation) SetDefaultMappedModel(s string) { + m.default_mapped_model = &s +} + +// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation. 
+func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) { + v := m.default_mapped_model + if v == nil { + return + } + return *v, true +} + +// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err) + } + return oldValue.DefaultMappedModel, nil +} + +// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field. +func (m *GroupMutation) ResetDefaultMappedModel() { + m.default_mapped_model = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -9450,7 +10426,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 30) + fields := make([]string, 0, 32) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -9541,6 +10517,12 @@ func (m *GroupMutation) Fields() []string { if m.sort_order != nil { fields = append(fields, group.FieldSortOrder) } + if m.allow_messages_dispatch != nil { + fields = append(fields, group.FieldAllowMessagesDispatch) + } + if m.default_mapped_model != nil { + fields = append(fields, group.FieldDefaultMappedModel) + } return fields } @@ -9609,6 +10591,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.SupportedModelScopes() case group.FieldSortOrder: return m.SortOrder() + case group.FieldAllowMessagesDispatch: + return m.AllowMessagesDispatch() + case group.FieldDefaultMappedModel: + return m.DefaultMappedModel() } return nil, false } @@ -9678,6 +10664,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldSupportedModelScopes(ctx) case group.FieldSortOrder: return m.OldSortOrder(ctx) + case group.FieldAllowMessagesDispatch: + return m.OldAllowMessagesDispatch(ctx) + case group.FieldDefaultMappedModel: + return m.OldDefaultMappedModel(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -9897,6 +10887,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetSortOrder(v) return nil + case group.FieldAllowMessagesDispatch: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowMessagesDispatch(v) + return nil + case group.FieldDefaultMappedModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDefaultMappedModel(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -10324,6 +11328,12 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldSortOrder: m.ResetSortOrder() return nil + case 
group.FieldAllowMessagesDispatch: + m.ResetAllowMessagesDispatch() + return nil + case group.FieldDefaultMappedModel: + m.ResetDefaultMappedModel() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 65531aae..b8facf36 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -102,6 +102,30 @@ func init() { apikeyDescQuotaUsed := apikeyFields[9].Descriptor() // apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field. apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64) + // apikeyDescRateLimit5h is the schema descriptor for rate_limit_5h field. + apikeyDescRateLimit5h := apikeyFields[11].Descriptor() + // apikey.DefaultRateLimit5h holds the default value on creation for the rate_limit_5h field. + apikey.DefaultRateLimit5h = apikeyDescRateLimit5h.Default.(float64) + // apikeyDescRateLimit1d is the schema descriptor for rate_limit_1d field. + apikeyDescRateLimit1d := apikeyFields[12].Descriptor() + // apikey.DefaultRateLimit1d holds the default value on creation for the rate_limit_1d field. + apikey.DefaultRateLimit1d = apikeyDescRateLimit1d.Default.(float64) + // apikeyDescRateLimit7d is the schema descriptor for rate_limit_7d field. + apikeyDescRateLimit7d := apikeyFields[13].Descriptor() + // apikey.DefaultRateLimit7d holds the default value on creation for the rate_limit_7d field. + apikey.DefaultRateLimit7d = apikeyDescRateLimit7d.Default.(float64) + // apikeyDescUsage5h is the schema descriptor for usage_5h field. + apikeyDescUsage5h := apikeyFields[14].Descriptor() + // apikey.DefaultUsage5h holds the default value on creation for the usage_5h field. + apikey.DefaultUsage5h = apikeyDescUsage5h.Default.(float64) + // apikeyDescUsage1d is the schema descriptor for usage_1d field. 
+ apikeyDescUsage1d := apikeyFields[15].Descriptor() + // apikey.DefaultUsage1d holds the default value on creation for the usage_1d field. + apikey.DefaultUsage1d = apikeyDescUsage1d.Default.(float64) + // apikeyDescUsage7d is the schema descriptor for usage_7d field. + apikeyDescUsage7d := apikeyFields[16].Descriptor() + // apikey.DefaultUsage7d holds the default value on creation for the usage_7d field. + apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64) accountMixin := schema.Account{}.Mixin() accountMixinHooks1 := accountMixin[1].Hooks() account.Hooks[0] = accountMixinHooks1[0] @@ -188,29 +212,29 @@ func init() { // account.DefaultConcurrency holds the default value on creation for the concurrency field. account.DefaultConcurrency = accountDescConcurrency.Default.(int) // accountDescPriority is the schema descriptor for priority field. - accountDescPriority := accountFields[8].Descriptor() + accountDescPriority := accountFields[9].Descriptor() // account.DefaultPriority holds the default value on creation for the priority field. account.DefaultPriority = accountDescPriority.Default.(int) // accountDescRateMultiplier is the schema descriptor for rate_multiplier field. - accountDescRateMultiplier := accountFields[9].Descriptor() + accountDescRateMultiplier := accountFields[10].Descriptor() // account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64) // accountDescStatus is the schema descriptor for status field. - accountDescStatus := accountFields[10].Descriptor() + accountDescStatus := accountFields[11].Descriptor() // account.DefaultStatus holds the default value on creation for the status field. account.DefaultStatus = accountDescStatus.Default.(string) // account.StatusValidator is a validator for the "status" field. It is called by the builders before save. 
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error) // accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field. - accountDescAutoPauseOnExpired := accountFields[14].Descriptor() + accountDescAutoPauseOnExpired := accountFields[15].Descriptor() // account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field. account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool) // accountDescSchedulable is the schema descriptor for schedulable field. - accountDescSchedulable := accountFields[15].Descriptor() + accountDescSchedulable := accountFields[16].Descriptor() // account.DefaultSchedulable holds the default value on creation for the schedulable field. account.DefaultSchedulable = accountDescSchedulable.Default.(bool) // accountDescSessionWindowStatus is the schema descriptor for session_window_status field. - accountDescSessionWindowStatus := accountFields[23].Descriptor() + accountDescSessionWindowStatus := accountFields[24].Descriptor() // account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save. account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error) accountgroupFields := schema.AccountGroup{}.Fields() @@ -253,12 +277,18 @@ func init() { announcement.DefaultStatus = announcementDescStatus.Default.(string) // announcement.StatusValidator is a validator for the "status" field. It is called by the builders before save. announcement.StatusValidator = announcementDescStatus.Validators[0].(func(string) error) + // announcementDescNotifyMode is the schema descriptor for notify_mode field. + announcementDescNotifyMode := announcementFields[3].Descriptor() + // announcement.DefaultNotifyMode holds the default value on creation for the notify_mode field. 
+ announcement.DefaultNotifyMode = announcementDescNotifyMode.Default.(string) + // announcement.NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save. + announcement.NotifyModeValidator = announcementDescNotifyMode.Validators[0].(func(string) error) // announcementDescCreatedAt is the schema descriptor for created_at field. - announcementDescCreatedAt := announcementFields[8].Descriptor() + announcementDescCreatedAt := announcementFields[9].Descriptor() // announcement.DefaultCreatedAt holds the default value on creation for the created_at field. announcement.DefaultCreatedAt = announcementDescCreatedAt.Default.(func() time.Time) // announcementDescUpdatedAt is the schema descriptor for updated_at field. - announcementDescUpdatedAt := announcementFields[9].Descriptor() + announcementDescUpdatedAt := announcementFields[10].Descriptor() // announcement.DefaultUpdatedAt holds the default value on creation for the updated_at field. announcement.DefaultUpdatedAt = announcementDescUpdatedAt.Default.(func() time.Time) // announcement.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. @@ -423,6 +453,16 @@ func init() { groupDescSortOrder := groupFields[26].Descriptor() // group.DefaultSortOrder holds the default value on creation for the sort_order field. group.DefaultSortOrder = groupDescSortOrder.Default.(int) + // groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field. + groupDescAllowMessagesDispatch := groupFields[27].Descriptor() + // group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field. + group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool) + // groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field. 
+ groupDescDefaultMappedModel := groupFields[28].Descriptor() + // group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field. + group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) + // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. + group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() _ = idempotencyrecordMixinFields0 diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index 443f9e09..5616d399 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -97,6 +97,8 @@ func (Account) Fields() []ent.Field { field.Int("concurrency"). Default(3), + field.Int("load_factor").Optional().Nillable(), + // priority: 账户优先级,数值越小优先级越高 // 调度器会优先使用高优先级的账户 field.Int("priority"). diff --git a/backend/ent/schema/announcement.go b/backend/ent/schema/announcement.go index 1568778f..14159fc3 100644 --- a/backend/ent/schema/announcement.go +++ b/backend/ent/schema/announcement.go @@ -41,6 +41,10 @@ func (Announcement) Fields() []ent.Field { MaxLen(20). Default(domain.AnnouncementStatusDraft). Comment("状态: draft, active, archived"), + field.String("notify_mode"). + MaxLen(20). + Default(domain.AnnouncementNotifyModeSilent). + Comment("通知模式: silent(仅铃铛), popup(弹窗提醒)"), field.JSON("targeting", domain.AnnouncementTargeting{}). Optional(). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index c1ac7ac3..5db51270 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -74,6 +74,47 @@ func (APIKey) Fields() []ent.Field { Optional(). Nillable(). 
Comment("Expiration time for this API key (null = never expires)"), + + // ========== Rate limit fields ========== + // Rate limit configuration (0 = unlimited) + field.Float("rate_limit_5h"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Rate limit in USD per 5 hours (0 = unlimited)"), + field.Float("rate_limit_1d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Rate limit in USD per day (0 = unlimited)"), + field.Float("rate_limit_7d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Rate limit in USD per 7 days (0 = unlimited)"), + // Rate limit usage tracking + field.Float("usage_5h"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Used amount in USD for the current 5h window"), + field.Float("usage_1d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Used amount in USD for the current 1d window"), + field.Float("usage_7d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Used amount in USD for the current 7d window"), + // Window start times + field.Time("window_5h_start"). + Optional(). + Nillable(). + Comment("Start time of the current 5h rate limit window"), + field.Time("window_1d_start"). + Optional(). + Nillable(). + Comment("Start time of the current 1d rate limit window"), + field.Time("window_7d_start"). + Optional(). + Nillable(). + Comment("Start time of the current 7d rate limit window"), } } diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 3fcf8674..0f5a7b14 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -148,6 +148,15 @@ func (Group) Fields() []ent.Field { field.Int("sort_order"). Default(0). Comment("分组显示排序,数值越小越靠前"), + + // OpenAI Messages 调度配置 (added by migration 069) + field.Bool("allow_messages_dispatch"). 
+ Default(false). + Comment("是否允许 /v1/messages 调度到此 OpenAI 分组"), + field.String("default_mapped_model"). + MaxLen(100). + Default(""). + Comment("默认映射模型 ID,当账号级映射找不到时使用此值"), } } diff --git a/backend/go.mod b/backend/go.mod index a34c9fff..135cbd3e 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,12 +1,13 @@ module github.com/Wei-Shaw/sub2api -go 1.25.7 +go 1.26.1 require ( entgo.io/ent v0.14.5 github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DouDOU-start/go-sora2api v1.1.0 github.com/alitto/pond/v2 v2.6.2 + github.com/aws/aws-sdk-go-v2 v1.41.3 github.com/aws/aws-sdk-go-v2/config v1.32.10 github.com/aws/aws-sdk-go-v2/credentials v1.19.10 github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 @@ -38,8 +39,6 @@ require ( golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 golang.org/x/term v0.40.0 - google.golang.org/grpc v1.75.1 - google.golang.org/protobuf v1.36.10 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.3 @@ -53,7 +52,6 @@ require ( github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect - github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect @@ -68,7 +66,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect - github.com/aws/smithy-go v1.24.1 // indirect + github.com/aws/smithy-go v1.24.2 // indirect github.com/bdandy/go-errors v1.2.2 // indirect github.com/bdandy/go-socks4 v1.2.3 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect @@ -109,7 +107,6 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect 
github.com/google/go-querystring v1.1.0 // indirect - github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -169,6 +166,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/sdk v1.37.0 // indirect go.opentelemetry.io/otel/trace v1.37.0 // indirect go.uber.org/atomic v1.10.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect @@ -178,8 +176,6 @@ require ( golang.org/x/mod v0.32.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect - golang.org/x/tools v0.41.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 32e389a7..324fe652 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls= github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4= +github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= +github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c= github.com/aws/aws-sdk-go-v2/config v1.32.10 
h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI= @@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8 github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs= github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM= github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic= @@ -171,8 +175,6 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -182,8 +184,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod 
h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= -github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -398,8 +398,6 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= -go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= -go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= @@ -455,8 +453,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api 
v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 4f6fea37..e90e56af 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -30,6 +30,14 @@ const ( // __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" +// UMQ(用户消息队列)模式常量 +const ( + // UMQModeSerialize: 账号级串行锁 + RPM 自适应延迟 + UMQModeSerialize = "serialize" + // UMQModeThrottle: 仅 RPM 自适应前置延迟,不阻塞并发 + UMQModeThrottle = "throttle" +) + // 连接池隔离策略常量 // 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗 const ( @@ -265,8 +273,13 @@ type CSPConfig struct { } type ProxyFallbackConfig struct { - // AllowDirectOnError 当代理初始化失败时是否允许回退直连。 - // 默认 false:避免因代理配置错误导致 IP 泄露/关联。 + // AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。 + // 仅影响以下非 AI 账号连接的辅助服务: + // - GitHub Release 更新检查 + // - 定价数据拉取 + // 不影响 AI 账号网关连接(Claude/OpenAI/Gemini/Antigravity), + // 这些关键路径的代理失败始终返回错误,不会回退直连。 + // 默认 false:避免因代理配置错误导致服务器真实 IP 泄露。 AllowDirectOnError bool `mapstructure:"allow_direct_on_error"` } @@ -450,6 +463,52 @@ type GatewayConfig struct { UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"` // ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒) ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"` + + // UserMessageQueue: 用户消息串行队列配置 + // 对 role:"user" 
的真实用户消息实施账号级串行化 + RPM 自适应延迟 + UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"` +} + +// UserMessageQueueConfig 用户消息串行队列配置 +// 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送 +type UserMessageQueueConfig struct { + // Mode: 模式选择 + // "serialize" = 账号级串行锁 + RPM 自适应延迟 + // "throttle" = 仅 RPM 自适应前置延迟,不阻塞并发 + // "" = 禁用(默认) + Mode string `mapstructure:"mode"` + // Enabled: 已废弃,仅向后兼容(等同于 mode: "serialize") + Enabled bool `mapstructure:"enabled"` + // LockTTLMs: 串行锁 TTL(毫秒),应大于最长请求时间 + LockTTLMs int `mapstructure:"lock_ttl_ms"` + // WaitTimeoutMs: 等待获取锁的超时时间(毫秒) + WaitTimeoutMs int `mapstructure:"wait_timeout_ms"` + // MinDelayMs: RPM 自适应延迟下限(毫秒) + MinDelayMs int `mapstructure:"min_delay_ms"` + // MaxDelayMs: RPM 自适应延迟上限(毫秒) + MaxDelayMs int `mapstructure:"max_delay_ms"` + // CleanupIntervalSeconds: 孤儿锁清理间隔(秒),0 表示禁用 + CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` +} + +// WaitTimeout 返回等待超时的 time.Duration +func (c *UserMessageQueueConfig) WaitTimeout() time.Duration { + if c.WaitTimeoutMs <= 0 { + return 30 * time.Second + } + return time.Duration(c.WaitTimeoutMs) * time.Millisecond +} + +// GetEffectiveMode 返回生效的模式 +// 注意:Mode 字段已在 load() 中做过白名单校验和规范化,此处无需重复验证 +func (c *UserMessageQueueConfig) GetEffectiveMode() string { + if c.Mode == UMQModeSerialize || c.Mode == UMQModeThrottle { + return c.Mode + } + if c.Enabled { + return UMQModeSerialize // 向后兼容 + } + return "" } // GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 @@ -457,7 +516,7 @@ type GatewayConfig struct { type GatewayOpenAIWSConfig struct { // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为) ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"` - // IngressModeDefault: ingress 默认模式(off/shared/dedicated) + // IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough) IngressModeDefault string `mapstructure:"ingress_mode_default"` // Enabled: 全局总开关(默认 true) Enabled bool `mapstructure:"enabled"` @@ -813,7 +872,8 @@ type DefaultConfig 
struct { } type RateLimitConfig struct { - OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) + OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) + OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401临时不可调度冷却(分钟) } // APIKeyAuthCacheConfig API Key 认证缓存配置 @@ -874,9 +934,10 @@ type DashboardAggregationConfig struct { // DashboardAggregationRetentionConfig 预聚合保留窗口 type DashboardAggregationRetentionConfig struct { - UsageLogsDays int `mapstructure:"usage_logs_days"` - HourlyDays int `mapstructure:"hourly_days"` - DailyDays int `mapstructure:"daily_days"` + UsageLogsDays int `mapstructure:"usage_logs_days"` + UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"` + HourlyDays int `mapstructure:"hourly_days"` + DailyDays int `mapstructure:"daily_days"` } // UsageCleanupConfig 使用记录清理任务配置 @@ -989,6 +1050,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds } + // Normalize UMQ mode: 白名单校验,非法值在加载时一次性 warn 并清空 + if m := cfg.Gateway.UserMessageQueue.Mode; m != "" && m != UMQModeSerialize && m != UMQModeThrottle { + slog.Warn("invalid user_message_queue mode, disabling", + "mode", m, + "valid_modes", []string{UMQModeSerialize, UMQModeThrottle}) + cfg.Gateway.UserMessageQueue.Mode = "" + } + // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256) cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey) if cfg.Totp.EncryptionKey == "" { @@ -1105,6 +1174,9 @@ func setDefaults() { viper.SetDefault("security.csp.policy", DefaultCSPPolicy) viper.SetDefault("security.proxy_probe.insecure_skip_verify", false) + // Security - disable direct fallback on proxy error + viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) + // Billing viper.SetDefault("billing.circuit_breaker.enabled", true) 
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5) @@ -1156,7 +1228,7 @@ func setDefaults() { // Ops (vNext) viper.SetDefault("ops.enabled", true) - viper.SetDefault("ops.use_preaggregated_tables", false) + viper.SetDefault("ops.use_preaggregated_tables", true) viper.SetDefault("ops.cleanup.enabled", true) viper.SetDefault("ops.cleanup.schedule", "0 2 * * *") // Retention days: vNext defaults to 30 days across ops datasets. @@ -1190,6 +1262,7 @@ func setDefaults() { // RateLimit viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) + viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10) // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移) viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json") @@ -1229,6 +1302,7 @@ func setDefaults() { viper.SetDefault("dashboard_aggregation.backfill_enabled", false) viper.SetDefault("dashboard_aggregation.backfill_max_days", 31) viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90) + viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365) viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180) viper.SetDefault("dashboard_aggregation.retention.daily_days", 730) viper.SetDefault("dashboard_aggregation.recompute_days", 2) @@ -1263,7 +1337,7 @@ func setDefaults() { // OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚) viper.SetDefault("gateway.openai_ws.enabled", true) viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false) - viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared") + viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool") viper.SetDefault("gateway.openai_ws.oauth_enabled", true) viper.SetDefault("gateway.openai_ws.apikey_enabled", true) viper.SetDefault("gateway.openai_ws.force_http", false) @@ -1330,7 +1404,7 @@ func setDefaults() { 
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_keepalive_interval", 10) - viper.SetDefault("gateway.max_line_size", 40*1024*1024) + viper.SetDefault("gateway.max_line_size", 500*1024*1024) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second) viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) @@ -1364,6 +1438,14 @@ func setDefaults() { viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30) viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15) // TLS指纹伪装配置(默认关闭,需要账号级别单独启用) + // 用户消息串行队列默认值 + viper.SetDefault("gateway.user_message_queue.enabled", false) + viper.SetDefault("gateway.user_message_queue.lock_ttl_ms", 120000) + viper.SetDefault("gateway.user_message_queue.wait_timeout_ms", 30000) + viper.SetDefault("gateway.user_message_queue.min_delay_ms", 200) + viper.SetDefault("gateway.user_message_queue.max_delay_ms", 2000) + viper.SetDefault("gateway.user_message_queue.cleanup_interval_seconds", 60) + viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("concurrency.ping_interval", 10) @@ -1415,9 +1497,6 @@ func setDefaults() { viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.quota.policy", "") - // Security - proxy fallback - viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) - // Subscription Maintenance (bounded queue + worker pool) viper.SetDefault("subscription_maintenance.worker_count", 2) viper.SetDefault("subscription_maintenance.queue_size", 1024) @@ -1681,6 +1760,12 @@ func (c *Config) Validate() error { if c.DashboardAgg.Retention.UsageLogsDays <= 0 { return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive") } + if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 { + return 
fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive") + } + if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days") + } if c.DashboardAgg.Retention.HourlyDays <= 0 { return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive") } @@ -1703,6 +1788,14 @@ func (c *Config) Validate() error { if c.DashboardAgg.Retention.UsageLogsDays < 0 { return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative") } + if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative") + } + if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 && + c.DashboardAgg.Retention.UsageLogsDays > 0 && + c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days") + } if c.DashboardAgg.Retention.HourlyDays < 0 { return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative") } @@ -1966,9 +2059,11 @@ func (c *Config) Validate() error { } if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" { switch mode { - case "off", "shared", "dedicated": + case "off", "ctx_pool", "passthrough": + case "shared", "dedicated": + slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode) default: - return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated") + return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough") } } if mode := 
strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" { diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index e3b592e2..abb76549 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) { if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled { t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false") } - if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" { - t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared") + if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" { + t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool") } } @@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { if cfg.DashboardAgg.Retention.UsageLogsDays != 90 { t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays) } + if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 { + t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays) + } if cfg.DashboardAgg.Retention.HourlyDays != 180 { t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays) } @@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 }, wantErr: "dashboard_aggregation.retention.usage_logs_days", }, + { + name: "dashboard aggregation dedup retention", + mutate: func(c *Config) { + c.DashboardAgg.Enabled = true + c.DashboardAgg.Retention.UsageBillingDedupDays = 0 + }, + wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days", + }, + { + name: "dashboard aggregation dedup retention smaller than usage logs", + 
mutate: func(c *Config) { + c.DashboardAgg.Enabled = true + c.DashboardAgg.Retention.UsageLogsDays = 30 + c.DashboardAgg.Retention.UsageBillingDedupDays = 29 + }, + wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days", + }, { name: "dashboard aggregation disabled interval", mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 }, @@ -1373,7 +1393,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) { wantErr: "gateway.openai_ws.store_disabled_conn_mode", }, { - name: "ingress_mode_default 必须为 off|shared|dedicated", + name: "ingress_mode_default 必须为 off|ctx_pool|passthrough", mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" }, wantErr: "gateway.openai_ws.ingress_mode_default", }, diff --git a/backend/internal/domain/announcement.go b/backend/internal/domain/announcement.go index 7dc9a9cc..0e68fb0f 100644 --- a/backend/internal/domain/announcement.go +++ b/backend/internal/domain/announcement.go @@ -13,6 +13,11 @@ const ( AnnouncementStatusArchived = "archived" ) +const ( + AnnouncementNotifyModeSilent = "silent" + AnnouncementNotifyModePopup = "popup" +) + const ( AnnouncementConditionTypeSubscription = "subscription" AnnouncementConditionTypeBalance = "balance" @@ -195,17 +200,18 @@ func (c AnnouncementCondition) validate() error { } type Announcement struct { - ID int64 - Title string - Content string - Status string - Targeting AnnouncementTargeting - StartsAt *time.Time - EndsAt *time.Time - CreatedBy *int64 - UpdatedBy *int64 - CreatedAt time.Time - UpdatedAt time.Time + ID int64 + Title string + Content string + Status string + NotifyMode string + Targeting AnnouncementTargeting + StartsAt *time.Time + EndsAt *time.Time + CreatedBy *int64 + UpdatedBy *int64 + CreatedAt time.Time + UpdatedAt time.Time } func (a *Announcement) IsActiveAt(now time.Time) bool { diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index d7bb50fc..c51046a2 100644 
--- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -31,6 +31,7 @@ const ( AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) AccountTypeAPIKey = "apikey" // API Key类型账号 AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) ) // Redeem type constants @@ -84,10 +85,12 @@ var DefaultAntigravityModelMapping = map[string]string{ "claude-haiku-4-5": "claude-sonnet-4-5", "claude-haiku-4-5-20251001": "claude-sonnet-4-5", // Gemini 2.5 白名单 - "gemini-2.5-flash": "gemini-2.5-flash", - "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", - "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", - "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-image": "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview": "gemini-2.5-flash-image", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", // Gemini 3 白名单 "gemini-3-flash": "gemini-3-flash", "gemini-3-pro-high": "gemini-3-pro-high", @@ -111,3 +114,27 @@ var DefaultAntigravityModelMapping = map[string]string{ "gpt-oss-120b-medium": "gpt-oss-120b-medium", "tab_flash_lite_preview": "tab_flash_lite_preview", } + +// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射 +// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID +// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的 +// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 
等) +var DefaultBedrockModelMapping = map[string]string{ + // Claude Opus + "claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1", + "claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1", + "claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0", + "claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0", + "claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0", + "claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0", + // Claude Sonnet + "claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6", + "claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6", + "claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0", + // Claude Haiku + "claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0", + "claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0", +} diff --git a/backend/internal/domain/constants_test.go b/backend/internal/domain/constants_test.go index 29605ac6..de66137f 100644 --- a/backend/internal/domain/constants_test.go +++ b/backend/internal/domain/constants_test.go @@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) t.Parallel() cases := map[string]string{ + "gemini-2.5-flash-image": "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview": "gemini-2.5-flash-image", "gemini-3.1-flash-image": "gemini-3.1-flash-image", "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", "gemini-3-pro-image": "gemini-3.1-flash-image", diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 4ce17219..fbac73d3 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -8,6 
+8,9 @@ import ( "strings" "time" + "log/slog" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) } } + enrichCredentialsFromIDToken(&item) + accountInput := &service.CreateAccountInput{ Name: item.Name, Notes: item.Notes, @@ -535,6 +540,57 @@ func defaultProxyName(name string) string { return name } +// enrichCredentialsFromIDToken performs best-effort extraction of user info fields +// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials. +// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently. +// Existing credential values are never overwritten — only missing fields are filled. +func enrichCredentialsFromIDToken(item *DataAccount) { + if item.Credentials == nil { + return + } + // Only enrich OpenAI/Sora OAuth accounts + platform := strings.ToLower(strings.TrimSpace(item.Platform)) + if platform != service.PlatformOpenAI && platform != service.PlatformSora { + return + } + if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth { + return + } + + idToken, _ := item.Credentials["id_token"].(string) + if strings.TrimSpace(idToken) == "" { + return + } + + // DecodeIDToken skips expiry validation — safe for imported data + claims, err := openai.DecodeIDToken(idToken) + if err != nil { + slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err) + return + } + + userInfo := claims.GetUserInfo() + if userInfo == nil { + return + } + + // Fill missing fields only (never overwrite existing values) + setIfMissing := func(key, value string) { + if value == "" { + return + } + if existing, _ := item.Credentials[key].(string); existing == "" { + item.Credentials[key] = value + } + } + + setIfMissing("email", userInfo.Email) + setIfMissing("plan_type", 
userInfo.PlanType) + setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID) + setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID) + setIfMissing("organization_id", userInfo.OrganizationID) +} + func normalizeProxyStatus(status string) string { normalized := strings.TrimSpace(strings.ToLower(status)) switch normalized { diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 98ead284..3ef213e1 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/http" "strconv" "strings" @@ -18,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -95,13 +97,14 @@ type CreateAccountRequest struct { Name string `json:"name" binding:"required"` Notes *string `json:"notes"` Platform string `json:"platform" binding:"required"` - Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"` + Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"` Credentials map[string]any `json:"credentials" binding:"required"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` Concurrency int `json:"concurrency"` Priority int `json:"priority"` RateMultiplier *float64 `json:"rate_multiplier"` + LoadFactor *int `json:"load_factor"` GroupIDs []int64 `json:"group_ids"` ExpiresAt *int64 `json:"expires_at"` AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` @@ -113,14 +116,15 @@ type CreateAccountRequest struct { type UpdateAccountRequest struct { Name string `json:"name"` Notes 
*string `json:"notes"` - Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"` + Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` Concurrency *int `json:"concurrency"` Priority *int `json:"priority"` RateMultiplier *float64 `json:"rate_multiplier"` - Status string `json:"status" binding:"omitempty,oneof=active inactive"` + LoadFactor *int `json:"load_factor"` + Status string `json:"status" binding:"omitempty,oneof=active inactive error"` GroupIDs *[]int64 `json:"group_ids"` ExpiresAt *int64 `json:"expires_at"` AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` @@ -135,6 +139,7 @@ type BulkUpdateAccountsRequest struct { Concurrency *int `json:"concurrency"` Priority *int `json:"priority"` RateMultiplier *float64 `json:"rate_multiplier"` + LoadFactor *int `json:"load_factor"` Status string `json:"status" binding:"omitempty,oneof=active inactive error"` Schedulable *bool `json:"schedulable"` GroupIDs *[]int64 `json:"group_ids"` @@ -217,6 +222,7 @@ func (h *AccountHandler) List(c *gin.Context) { if len(search) > 100 { search = search[:100] } + lite := parseBoolQueryWithDefault(c.Query("lite"), false) var groupID int64 if groupIDStr := c.Query("group"); groupIDStr != "" { @@ -235,10 +241,16 @@ func (h *AccountHandler) List(c *gin.Context) { accountIDs[i] = acc.ID } - concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs) - if err != nil { - // Log error but don't fail the request, just use 0 for all - concurrencyCounts = make(map[int64]int) + concurrencyCounts := make(map[int64]int) + var windowCosts map[int64]float64 + var activeSessions map[int64]int + var rpmCounts map[int64]int + + // 始终获取并发数(Redis ZCARD,极低开销) + if h.concurrencyService != nil { + if cc, ccErr := 
h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil { + concurrencyCounts = cc + } } // 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能) @@ -262,12 +274,7 @@ func (h *AccountHandler) List(c *gin.Context) { } } - // 并行获取窗口费用、活跃会话数和 RPM 计数 - var windowCosts map[int64]float64 - var activeSessions map[int64]int - var rpmCounts map[int64]int - - // 获取 RPM 计数(批量查询) + // 始终获取 RPM 计数(Redis GET,极低开销) if len(rpmAccountIDs) > 0 && h.rpmCache != nil { rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs) if rpmCounts == nil { @@ -275,7 +282,7 @@ func (h *AccountHandler) List(c *gin.Context) { } } - // 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置) + // 始终获取活跃会话数(Redis ZCARD,低开销) if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil { activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts) if activeSessions == nil { @@ -283,7 +290,7 @@ func (h *AccountHandler) List(c *gin.Context) { } } - // 获取窗口费用(并行查询) + // 始终获取窗口费用(PostgreSQL 聚合查询) if len(windowCostAccountIDs) > 0 { windowCosts = make(map[int64]float64) var mu sync.Mutex @@ -344,7 +351,7 @@ func (h *AccountHandler) List(c *gin.Context) { result[i] = item } - etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search) + etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite) if etag != "" { c.Header("ETag", etag) c.Header("Vary", "If-None-Match") @@ -362,6 +369,7 @@ func buildAccountsListETag( total int64, page, pageSize int, platform, accountType, status, search string, + lite bool, ) string { payload := struct { Total int64 `json:"total"` @@ -371,6 +379,7 @@ func buildAccountsListETag( AccountType string `json:"type"` Status string `json:"status"` Search string `json:"search"` + Lite bool `json:"lite"` Items []AccountWithConcurrency `json:"items"` }{ Total: total, @@ -380,6 
+389,7 @@ func buildAccountsListETag( AccountType: accountType, Status: status, Search: search, + Lite: lite, Items: items, } raw, err := json.Marshal(payload) @@ -501,6 +511,7 @@ func (h *AccountHandler) Create(c *gin.Context) { Concurrency: req.Concurrency, Priority: req.Priority, RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, GroupIDs: req.GroupIDs, ExpiresAt: req.ExpiresAt, AutoPauseOnExpired: req.AutoPauseOnExpired, @@ -570,6 +581,7 @@ func (h *AccountHandler) Update(c *gin.Context) { Concurrency: req.Concurrency, // 指针类型,nil 表示未提供 Priority: req.Priority, // 指针类型,nil 表示未提供 RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, Status: req.Status, GroupIDs: req.GroupIDs, ExpiresAt: req.ExpiresAt, @@ -616,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) { // TestAccountRequest represents the request body for testing an account type TestAccountRequest struct { ModelID string `json:"model_id"` + Prompt string `json:"prompt"` } type SyncFromCRSRequest struct { @@ -646,10 +659,46 @@ func (h *AccountHandler) Test(c *gin.Context) { _ = c.ShouldBindJSON(&req) // Use AccountTestService to test the account with SSE streaming - if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil { + if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil { // Error already sent via SSE, just log return } + + if h.rateLimitService != nil { + if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil { + _ = c.Error(err) + } + } +} + +// RecoverState handles unified recovery of recoverable account runtime state. 
+// POST /api/v1/admin/accounts/:id/recover-state +func (h *AccountHandler) RecoverState(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + + if h.rateLimitService == nil { + response.Error(c, http.StatusServiceUnavailable, "Rate limit service unavailable") + return + } + + if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{ + InvalidateToken: true, + }); err != nil { + response.ErrorFrom(c, err) + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // SyncFromCRS handles syncing accounts from claude-relay-service (CRS) @@ -705,52 +754,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) { response.Success(c, result) } -// Refresh handles refreshing account credentials -// POST /api/v1/admin/accounts/:id/refresh -func (h *AccountHandler) Refresh(c *gin.Context) { - accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.BadRequest(c, "Invalid account ID") - return - } - - // Get account - account, err := h.adminService.GetAccount(c.Request.Context(), accountID) - if err != nil { - response.NotFound(c, "Account not found") - return - } - - // Only refresh OAuth-based accounts (oauth and setup-token) +// refreshSingleAccount refreshes credentials for a single OAuth account. +// Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario. 
+func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) { if !account.IsOAuth() { - response.BadRequest(c, "Cannot refresh non-OAuth account credentials") - return + return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account") } var newCredentials map[string]any if account.IsOpenAI() { - // Use OpenAI OAuth service to refresh token - tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account) + tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account) if err != nil { - response.ErrorFrom(c, err) - return + return nil, "", err } - // Build new credentials from token info newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo) - - // Preserve non-token settings from existing credentials for k, v := range account.Credentials { if _, exists := newCredentials[k]; !exists { newCredentials[k] = v } } } else if account.Platform == service.PlatformGemini { - tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account) + tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account) if err != nil { - response.InternalError(c, "Failed to refresh credentials: "+err.Error()) - return + return nil, "", fmt.Errorf("failed to refresh credentials: %w", err) } newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo) @@ -760,10 +788,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) { } } } else if account.Platform == service.PlatformAntigravity { - tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account) + tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account) if err != nil { - response.ErrorFrom(c, err) - return + return nil, "", err } newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo) @@ -782,37 +809,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) { } // 如果 project_id 
获取失败,更新凭证但不标记为 error - // LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试 if tokenInfo.ProjectIDMissing { - // 先更新凭证(token 本身刷新成功了) - _, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ + updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{ Credentials: newCredentials, }) if updateErr != nil { - response.InternalError(c, "Failed to update credentials: "+updateErr.Error()) - return + return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr) } - // 不标记为 error,只返回警告信息 - response.Success(c, gin.H{ - "message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)", - "warning": "missing_project_id_temporary", - }) - return + return updatedAccount, "missing_project_id_temporary", nil } // 成功获取到 project_id,如果之前是 missing_project_id 错误则清除 if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") { - if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil { - response.InternalError(c, "Failed to clear account error: "+clearErr.Error()) - return + if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil { + return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr) } } } else { // Use Anthropic/Claude OAuth service to refresh token - tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account) + tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account) if err != nil { - response.ErrorFrom(c, err) - return + return nil, "", err } // Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests) @@ -834,20 +851,54 @@ func (h *AccountHandler) Refresh(c *gin.Context) { } } - updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ + updatedAccount, err := 
h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{ Credentials: newCredentials, }) + if err != nil { + return nil, "", err + } + + // 刷新成功后,清除 token 缓存,确保下次请求使用新 token + if h.tokenCacheInvalidator != nil { + if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil { + log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr) + } + } + + // OpenAI OAuth: 刷新成功后检查并设置 privacy_mode + h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount) + + return updatedAccount, "", nil +} + +// Refresh handles refreshing account credentials +// POST /api/v1/admin/accounts/:id/refresh +func (h *AccountHandler) Refresh(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + + // Get account + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.NotFound(c, "Account not found") + return + } + + updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account) if err != nil { response.ErrorFrom(c, err) return } - // 刷新成功后,清除 token 缓存,确保下次请求使用新 token - if h.tokenCacheInvalidator != nil { - if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil { - // 缓存失效失败只记录日志,不影响主流程 - _ = c.Error(invalidateErr) - } + if warning == "missing_project_id_temporary" { + response.Success(c, gin.H{ + "message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)", + "warning": "missing_project_id_temporary", + }) + return } response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount)) @@ -903,14 +954,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) { // 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题 if h.tokenCacheInvalidator != nil && account.IsOAuth() { if invalidateErr := 
h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil { - // 缓存失效失败只记录日志,不影响主流程 - _ = c.Error(invalidateErr) + log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr) } } response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } +// BatchClearError handles batch clearing account errors +// POST /api/v1/admin/accounts/batch-clear-error +func (h *AccountHandler) BatchClearError(c *gin.Context) { + var req struct { + AccountIDs []int64 `json:"account_ids"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if len(req.AccountIDs) == 0 { + response.BadRequest(c, "account_ids is required") + return + } + + ctx := c.Request.Context() + + const maxConcurrency = 10 + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(maxConcurrency) + + var mu sync.Mutex + var successCount, failedCount int + var errors []gin.H + + // 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务 + for _, id := range req.AccountIDs { + accountID := id // 闭包捕获 + g.Go(func() error { + account, err := h.adminService.ClearAccountError(gctx, accountID) + if err != nil { + mu.Lock() + failedCount++ + errors = append(errors, gin.H{ + "account_id": accountID, + "error": err.Error(), + }) + mu.Unlock() + return nil + } + + // 清除错误后,同时清除 token 缓存 + if h.tokenCacheInvalidator != nil && account.IsOAuth() { + if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil { + log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr) + } + } + + mu.Lock() + successCount++ + mu.Unlock() + return nil + }) + } + + if err := g.Wait(); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "total": len(req.AccountIDs), + "success": successCount, + "failed": failedCount, + "errors": errors, + }) +} + +// BatchRefresh handles batch 
refreshing account credentials +// POST /api/v1/admin/accounts/batch-refresh +func (h *AccountHandler) BatchRefresh(c *gin.Context) { + var req struct { + AccountIDs []int64 `json:"account_ids"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if len(req.AccountIDs) == 0 { + response.BadRequest(c, "account_ids is required") + return + } + + ctx := c.Request.Context() + + accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // 建立已获取账号的 ID 集合,检测缺失的 ID + foundIDs := make(map[int64]bool, len(accounts)) + for _, acc := range accounts { + if acc != nil { + foundIDs[acc.ID] = true + } + } + + const maxConcurrency = 10 + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(maxConcurrency) + + var mu sync.Mutex + var successCount, failedCount int + var errors []gin.H + var warnings []gin.H + + // 将不存在的账号 ID 标记为失败 + for _, id := range req.AccountIDs { + if !foundIDs[id] { + failedCount++ + errors = append(errors, gin.H{ + "account_id": id, + "error": "account not found", + }) + } + } + + // 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务 + for _, account := range accounts { + acc := account // 闭包捕获 + if acc == nil { + continue + } + g.Go(func() error { + _, warning, err := h.refreshSingleAccount(gctx, acc) + mu.Lock() + if err != nil { + failedCount++ + errors = append(errors, gin.H{ + "account_id": acc.ID, + "error": err.Error(), + }) + } else { + successCount++ + if warning != "" { + warnings = append(warnings, gin.H{ + "account_id": acc.ID, + "warning": warning, + }) + } + } + mu.Unlock() + return nil + }) + } + + if err := g.Wait(); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "total": len(req.AccountIDs), + "success": successCount, + "failed": failedCount, + "errors": errors, + "warnings": warnings, + }) +} + // BatchCreate handles batch creating accounts // POST 
/api/v1/admin/accounts/batch func (h *AccountHandler) BatchCreate(c *gin.Context) { @@ -1096,6 +1308,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { req.Concurrency != nil || req.Priority != nil || req.RateMultiplier != nil || + req.LoadFactor != nil || req.Status != "" || req.Schedulable != nil || req.GroupIDs != nil || @@ -1114,6 +1327,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { Concurrency: req.Concurrency, Priority: req.Priority, RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, Status: req.Status, Schedulable: req.Schedulable, GroupIDs: req.GroupIDs, @@ -1323,6 +1537,29 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) { response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } +// ResetQuota handles resetting account quota usage +// POST /api/v1/admin/accounts/:id/reset-quota +func (h *AccountHandler) ResetQuota(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + + if err := h.adminService.ResetAccountQuota(c.Request.Context(), accountID); err != nil { + response.InternalError(c, "Failed to reset account quota: "+err.Error()) + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) +} + // GetTempUnschedulable handles getting temporary unschedulable status // GET /api/v1/admin/accounts/:id/temp-unschedulable func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) { @@ -1398,18 +1635,41 @@ func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) { return } - if len(req.AccountIDs) == 0 { + accountIDs := normalizeInt64IDList(req.AccountIDs) + if len(accountIDs) == 0 { response.Success(c, gin.H{"stats": map[string]any{}}) return } - stats, err := 
h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs) + cacheKey := buildAccountTodayStatsBatchCacheKey(accountIDs) + if cached, ok := accountTodayStatsBatchCache.Get(cacheKey); ok { + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) { + c.Status(http.StatusNotModified) + return + } + } + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), accountIDs) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, gin.H{"stats": stats}) + payload := gin.H{"stats": stats} + cached := accountTodayStatsBatchCache.Set(cacheKey, payload) + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + } + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) } // SetSchedulableRequest represents the request body for setting schedulable status @@ -1458,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { // Handle OpenAI accounts if account.IsOpenAI() { - // For OAuth accounts: return default OpenAI models - if account.IsOAuth() { + // OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。 + if account.IsOpenAIPassthroughEnabled() { response.Success(c, openai.DefaultModels) return } - // For API Key accounts: check model_mapping mapping := account.GetModelMapping() if len(mapping) == 0 { response.Success(c, openai.DefaultModels) diff --git a/backend/internal/handler/admin/account_handler_available_models_test.go b/backend/internal/handler/admin/account_handler_available_models_test.go new file mode 100644 index 00000000..c5f1e2d8 --- /dev/null +++ b/backend/internal/handler/admin/account_handler_available_models_test.go @@ -0,0 +1,105 @@ +package admin + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + 
"github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type availableModelsAdminService struct { + *stubAdminService + account service.Account +} + +func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) { + if s.account.ID == id { + acc := s.account + return &acc, nil + } + return s.stubAdminService.GetAccount(context.Background(), id) +} + +func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels) + return router +} + +func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) { + svc := &availableModelsAdminService{ + stubAdminService: newStubAdminService(), + account: service.Account{ + ID: 42, + Name: "openai-oauth", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5.1", + }, + }, + }, + } + router := setupAvailableModelsRouter(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Len(t, resp.Data, 1) + require.Equal(t, "gpt-5", resp.Data[0].ID) +} + +func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) { + svc := &availableModelsAdminService{ + stubAdminService: newStubAdminService(), + account: service.Account{ + ID: 43, + Name: "openai-oauth-passthrough", + Platform: service.PlatformOpenAI, + 
Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5.1", + }, + }, + Extra: map[string]any{ + "openai_passthrough": true, + }, + }, + } + router := setupAvailableModelsRouter(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotEmpty(t, resp.Data) + require.NotEqual(t, "gpt-5", resp.Data[0].ID) +} diff --git a/backend/internal/handler/admin/account_today_stats_cache.go b/backend/internal/handler/admin/account_today_stats_cache.go new file mode 100644 index 00000000..61922f70 --- /dev/null +++ b/backend/internal/handler/admin/account_today_stats_cache.go @@ -0,0 +1,25 @@ +package admin + +import ( + "strconv" + "strings" + "time" +) + +var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second) + +func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string { + if len(accountIDs) == 0 { + return "accounts_today_stats_empty" + } + var b strings.Builder + b.Grow(len(accountIDs) * 6) + _, _ = b.WriteString("accounts_today_stats:") + for i, id := range accountIDs { + if i > 0 { + _ = b.WriteByte(',') + } + _, _ = b.WriteString(strconv.FormatInt(id, 10)) + } + return b.String() +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index f3b99ddb..37a72cb4 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -175,6 +175,18 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p return s.apiKeys, int64(len(s.apiKeys)), nil } +func (s *stubAdminService) 
GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) { + return nil, nil +} + +func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error { + return nil +} + +func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error { + return nil +} + func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) { return s.accounts, int64(len(s.accounts)), nil } @@ -425,5 +437,13 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i return nil, service.ErrAPIKeyNotFound } +func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error { + return nil +} + +func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string { + return "" +} + // Ensure stub implements interface. 
var _ service.AdminService = (*stubAdminService)(nil) diff --git a/backend/internal/handler/admin/announcement_handler.go b/backend/internal/handler/admin/announcement_handler.go index 0b5d0fbc..d1312bc0 100644 --- a/backend/internal/handler/admin/announcement_handler.go +++ b/backend/internal/handler/admin/announcement_handler.go @@ -27,21 +27,23 @@ func NewAnnouncementHandler(announcementService *service.AnnouncementService) *A } type CreateAnnouncementRequest struct { - Title string `json:"title" binding:"required"` - Content string `json:"content" binding:"required"` - Status string `json:"status" binding:"omitempty,oneof=draft active archived"` - Targeting service.AnnouncementTargeting `json:"targeting"` - StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate - EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + Status string `json:"status" binding:"omitempty,oneof=draft active archived"` + NotifyMode string `json:"notify_mode" binding:"omitempty,oneof=silent popup"` + Targeting service.AnnouncementTargeting `json:"targeting"` + StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate + EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never } type UpdateAnnouncementRequest struct { - Title *string `json:"title"` - Content *string `json:"content"` - Status *string `json:"status" binding:"omitempty,oneof=draft active archived"` - Targeting *service.AnnouncementTargeting `json:"targeting"` - StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear - EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear + Title *string `json:"title"` + Content *string `json:"content"` + Status *string `json:"status" binding:"omitempty,oneof=draft active archived"` + NotifyMode *string `json:"notify_mode" binding:"omitempty,oneof=silent popup"` + Targeting *service.AnnouncementTargeting `json:"targeting"` + StartsAt 
*int64 `json:"starts_at"` // Unix seconds, 0 = clear + EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear } // List handles listing announcements with filters @@ -110,11 +112,12 @@ func (h *AnnouncementHandler) Create(c *gin.Context) { } input := &service.CreateAnnouncementInput{ - Title: req.Title, - Content: req.Content, - Status: req.Status, - Targeting: req.Targeting, - ActorID: &subject.UserID, + Title: req.Title, + Content: req.Content, + Status: req.Status, + NotifyMode: req.NotifyMode, + Targeting: req.Targeting, + ActorID: &subject.UserID, } if req.StartsAt != nil && *req.StartsAt > 0 { @@ -157,11 +160,12 @@ func (h *AnnouncementHandler) Update(c *gin.Context) { } input := &service.UpdateAnnouncementInput{ - Title: req.Title, - Content: req.Content, - Status: req.Status, - Targeting: req.Targeting, - ActorID: &subject.UserID, + Title: req.Title, + Content: req.Content, + Status: req.Status, + NotifyMode: req.NotifyMode, + Targeting: req.Targeting, + ActorID: &subject.UserID, } if req.StartsAt != nil { diff --git a/backend/internal/handler/admin/backup_handler.go b/backend/internal/handler/admin/backup_handler.go new file mode 100644 index 00000000..d19713ee --- /dev/null +++ b/backend/internal/handler/admin/backup_handler.go @@ -0,0 +1,204 @@ +package admin + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type BackupHandler struct { + backupService *service.BackupService + userService *service.UserService +} + +func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler { + return &BackupHandler{ + backupService: backupService, + userService: userService, + } +} + +// ─── S3 配置 ─── + +func (h *BackupHandler) GetS3Config(c *gin.Context) { + cfg, err := h.backupService.GetS3Config(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, 
err) + return + } + response.Success(c, cfg) +} + +func (h *BackupHandler) UpdateS3Config(c *gin.Context) { + var req service.BackupS3Config + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *BackupHandler) TestS3Connection(c *gin.Context) { + var req service.BackupS3Config + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + err := h.backupService.TestS3Connection(c.Request.Context(), req) + if err != nil { + response.Success(c, gin.H{"ok": false, "message": err.Error()}) + return + } + response.Success(c, gin.H{"ok": true, "message": "connection successful"}) +} + +// ─── 定时备份 ─── + +func (h *BackupHandler) GetSchedule(c *gin.Context) { + cfg, err := h.backupService.GetSchedule(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *BackupHandler) UpdateSchedule(c *gin.Context) { + var req service.BackupScheduleConfig + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +// ─── 备份操作 ─── + +type CreateBackupRequest struct { + ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期 +} + +func (h *BackupHandler) CreateBackup(c *gin.Context) { + var req CreateBackupRequest + _ = c.ShouldBindJSON(&req) // 允许空 body + + expireDays := 14 // 默认14天过期 + if req.ExpireDays != nil { + expireDays = *req.ExpireDays + } + + record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays) + if err != nil { + response.ErrorFrom(c, err) + return + } + 
response.Success(c, record) +} + +func (h *BackupHandler) ListBackups(c *gin.Context) { + records, err := h.backupService.ListBackups(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + if records == nil { + records = []service.BackupRecord{} + } + response.Success(c, gin.H{"items": records}) +} + +func (h *BackupHandler) GetBackup(c *gin.Context) { + backupID := c.Param("id") + if backupID == "" { + response.BadRequest(c, "backup ID is required") + return + } + record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, record) +} + +func (h *BackupHandler) DeleteBackup(c *gin.Context) { + backupID := c.Param("id") + if backupID == "" { + response.BadRequest(c, "backup ID is required") + return + } + if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +func (h *BackupHandler) GetDownloadURL(c *gin.Context) { + backupID := c.Param("id") + if backupID == "" { + response.BadRequest(c, "backup ID is required") + return + } + url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"url": url}) +} + +// ─── 恢复操作(需要重新输入管理员密码) ─── + +type RestoreBackupRequest struct { + Password string `json:"password" binding:"required"` +} + +func (h *BackupHandler) RestoreBackup(c *gin.Context) { + backupID := c.Param("id") + if backupID == "" { + response.BadRequest(c, "backup ID is required") + return + } + + var req RestoreBackupRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "password is required for restore operation") + return + } + + // 从上下文获取当前管理员用户 ID + sub, ok := middleware.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "unauthorized") + return + } + + // 
获取管理员用户并验证密码 + user, err := h.userService.GetByID(c.Request.Context(), sub.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if !user.CheckPassword(req.Password) { + response.BadRequest(c, "incorrect admin password") + return + } + + if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"restored": true}) +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 1d48c653..f415b48f 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -1,6 +1,7 @@ package admin import ( + "encoding/json" "errors" "strconv" "strings" @@ -248,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { } } - trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) + trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get usage trend") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "trend": trend, @@ -320,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { } } - stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "models": 
stats, @@ -390,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) { } } - stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get group statistics") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "groups": stats, @@ -415,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) { limit = 5 } - trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit) + trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit) if err != nil { response.Error(c, 500, "Failed to get API key usage trend") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "trend": trend, @@ -441,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) { limit = 12 } - trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit) + trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit) if err != nil { response.Error(c, 500, "Failed to get user usage trend") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "trend": trend, @@ -460,6 +466,62 @@ type BatchUsersUsageRequest struct { UserIDs []int64 `json:"user_ids" binding:"required"` } +var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) +var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second) +var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second) + +func parseRankingLimit(raw string) int { + limit, err := 
strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || limit <= 0 { + return 12 + } + if limit > 50 { + return 50 + } + return limit +} + +// GetUserSpendingRanking handles getting user spending ranking data. +// GET /api/v1/admin/dashboard/users-ranking +func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) { + startTime, endTime := parseTimeRange(c) + limit := parseRankingLimit(c.DefaultQuery("limit", "12")) + + keyRaw, _ := json.Marshal(struct { + Start string `json:"start"` + End string `json:"end"` + Limit int `json:"limit"` + }{ + Start: startTime.UTC().Format(time.RFC3339), + End: endTime.UTC().Format(time.RFC3339), + Limit: limit, + }) + cacheKey := string(keyRaw) + if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit) + if err != nil { + response.Error(c, 500, "Failed to get user spending ranking") + return + } + + payload := gin.H{ + "ranking": ranking.Ranking, + "total_actual_cost": ranking.TotalActualCost, + "total_requests": ranking.TotalRequests, + "total_tokens": ranking.TotalTokens, + "start_date": startTime.Format("2006-01-02"), + "end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"), + } + dashboardUsersRankingCache.Set(cacheKey, payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) +} + // GetBatchUsersUsage handles getting usage stats for multiple users // POST /api/v1/admin/dashboard/users-usage func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { @@ -469,18 +531,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { return } - if len(req.UserIDs) == 0 { + userIDs := normalizeInt64IDList(req.UserIDs) + if len(userIDs) == 0 { response.Success(c, gin.H{"stats": map[string]any{}}) return } - stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), 
req.UserIDs, time.Time{}, time.Time{}) + keyRaw, _ := json.Marshal(struct { + UserIDs []int64 `json:"user_ids"` + }{ + UserIDs: userIDs, + }) + cacheKey := string(keyRaw) + if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get user usage stats") return } - response.Success(c, gin.H{"stats": stats}) + payload := gin.H{"stats": stats} + dashboardBatchUsersUsageCache.Set(cacheKey, payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) } // BatchAPIKeysUsageRequest represents the request body for batch api key usage stats @@ -497,16 +575,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) { return } - if len(req.APIKeyIDs) == 0 { + apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs) + if len(apiKeyIDs) == 0 { response.Success(c, gin.H{"stats": map[string]any{}}) return } - stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{}) + keyRaw, _ := json.Marshal(struct { + APIKeyIDs []int64 `json:"api_key_ids"` + }{ + APIKeyIDs: apiKeyIDs, + }) + cacheKey := string(keyRaw) + if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get API key usage stats") return } - response.Success(c, gin.H{"stats": stats}) + payload := gin.H{"stats": stats} + dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) } diff --git 
a/backend/internal/handler/admin/dashboard_handler_cache_test.go b/backend/internal/handler/admin/dashboard_handler_cache_test.go new file mode 100644 index 00000000..ec888849 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_handler_cache_test.go @@ -0,0 +1,118 @@ +package admin + +import ( + "context" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type dashboardUsageRepoCacheProbe struct { + service.UsageLogRepository + trendCalls atomic.Int32 + usersTrendCalls atomic.Int32 +} + +func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + userID, apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.TrendDataPoint, error) { + r.trendCalls.Add(1) + return []usagestats.TrendDataPoint{{ + Date: "2026-03-11", + Requests: 1, + TotalTokens: 2, + Cost: 3, + ActualCost: 4, + }}, nil +} + +func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + limit int, +) ([]usagestats.UserUsageTrendPoint, error) { + r.usersTrendCalls.Add(1) + return []usagestats.UserUsageTrendPoint{{ + Date: "2026-03-11", + UserID: 1, + Email: "cache@test.dev", + Requests: 2, + Tokens: 20, + Cost: 2, + ActualCost: 1, + }}, nil +} + +func resetDashboardReadCachesForTest() { + dashboardTrendCache = newSnapshotCache(30 * time.Second) + dashboardUsersTrendCache = newSnapshotCache(30 * time.Second) + dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second) + dashboardModelStatsCache = newSnapshotCache(30 * time.Second) + dashboardGroupStatsCache = newSnapshotCache(30 * time.Second) + dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second) +} + 
+func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) { + t.Cleanup(resetDashboardReadCachesForTest) + resetDashboardReadCachesForTest() + + gin.SetMode(gin.TestMode) + repo := &dashboardUsageRepoCacheProbe{} + dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) + handler := NewDashboardHandler(dashboardSvc, nil) + router := gin.New() + router.GET("/admin/dashboard/trend", handler.GetUsageTrend) + + req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil) + rec1 := httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code) + require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache")) + + req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil) + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code) + require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache")) + require.Equal(t, int32(1), repo.trendCalls.Load()) +} + +func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) { + t.Cleanup(resetDashboardReadCachesForTest) + resetDashboardReadCachesForTest() + + gin.SetMode(gin.TestMode) + repo := &dashboardUsageRepoCacheProbe{} + dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) + handler := NewDashboardHandler(dashboardSvc, nil) + router := gin.New() + router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend) + + req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil) + rec1 := httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code) + require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache")) + + req2 := httptest.NewRequest(http.MethodGet, 
"/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil) + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code) + require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache")) + require.Equal(t, int32(1), repo.usersTrendCalls.Load()) +} diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go index 72af6b45..9aec61d4 100644 --- a/backend/internal/handler/admin/dashboard_handler_request_type_test.go +++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go @@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct { trendStream *bool modelRequestType *int16 modelStream *bool + rankingLimit int + ranking []usagestats.UserSpendingRankingItem + rankingTotal float64 } func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters( @@ -49,6 +52,20 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters( return []usagestats.ModelStat{}, nil } +func (s *dashboardUsageRepoCapture) GetUserSpendingRanking( + ctx context.Context, + startTime, endTime time.Time, + limit int, +) (*usagestats.UserSpendingRankingResponse, error) { + s.rankingLimit = limit + return &usagestats.UserSpendingRankingResponse{ + Ranking: s.ranking, + TotalActualCost: s.rankingTotal, + TotalRequests: 44, + TotalTokens: 1234, + }, nil +} + func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine { gin.SetMode(gin.TestMode) dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) @@ -56,6 +73,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng router := gin.New() router.GET("/admin/dashboard/trend", handler.GetUsageTrend) router.GET("/admin/dashboard/models", handler.GetModelStats) + router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking) return router } @@ -130,3 +148,32 @@ func 
TestDashboardModelStatsInvalidStream(t *testing.T) { require.Equal(t, http.StatusBadRequest, rec.Code) } + +func TestDashboardUsersRankingLimitAndCache(t *testing.T) { + dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) + repo := &dashboardUsageRepoCapture{ + ranking: []usagestats.UserSpendingRankingItem{ + {UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300}, + }, + rankingTotal: 88.8, + } + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 50, repo.rankingLimit) + require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8") + require.Contains(t, rec.Body.String(), "\"total_requests\":44") + require.Contains(t, rec.Body.String(), "\"total_tokens\":1234") + require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache")) + + req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil) + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + + require.Equal(t, http.StatusOK, rec2.Code) + require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache")) +} diff --git a/backend/internal/handler/admin/dashboard_query_cache.go b/backend/internal/handler/admin/dashboard_query_cache.go new file mode 100644 index 00000000..47af5117 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_query_cache.go @@ -0,0 +1,200 @@ +package admin + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" +) + +var ( + dashboardTrendCache = newSnapshotCache(30 * time.Second) + dashboardModelStatsCache = newSnapshotCache(30 * time.Second) + dashboardGroupStatsCache = newSnapshotCache(30 * time.Second) + dashboardUsersTrendCache = 
newSnapshotCache(30 * time.Second) + dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second) +) + +type dashboardTrendCacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Granularity string `json:"granularity"` + UserID int64 `json:"user_id"` + APIKeyID int64 `json:"api_key_id"` + AccountID int64 `json:"account_id"` + GroupID int64 `json:"group_id"` + Model string `json:"model"` + RequestType *int16 `json:"request_type"` + Stream *bool `json:"stream"` + BillingType *int8 `json:"billing_type"` +} + +type dashboardModelGroupCacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + UserID int64 `json:"user_id"` + APIKeyID int64 `json:"api_key_id"` + AccountID int64 `json:"account_id"` + GroupID int64 `json:"group_id"` + RequestType *int16 `json:"request_type"` + Stream *bool `json:"stream"` + BillingType *int8 `json:"billing_type"` +} + +type dashboardEntityTrendCacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Granularity string `json:"granularity"` + Limit int `json:"limit"` +} + +func cacheStatusValue(hit bool) string { + if hit { + return "hit" + } + return "miss" +} + +func mustMarshalDashboardCacheKey(value any) string { + raw, err := json.Marshal(value) + if err != nil { + return "" + } + return string(raw) +} + +func snapshotPayloadAs[T any](payload any) (T, error) { + typed, ok := payload.(T) + if !ok { + var zero T + return zero, fmt.Errorf("unexpected cache payload type %T", payload) + } + return typed, nil +} + +func (h *DashboardHandler) getUsageTrendCached( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + userID, apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.TrendDataPoint, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: 
endTime.UTC().Format(time.RFC3339), + Granularity: granularity, + UserID: userID, + APIKeyID: apiKeyID, + AccountID: accountID, + GroupID: groupID, + Model: model, + RequestType: requestType, + Stream: stream, + BillingType: billingType, + }) + entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) + }) + if err != nil { + return nil, hit, err + } + trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload) + return trend, hit, err +} + +func (h *DashboardHandler) getModelStatsCached( + ctx context.Context, + startTime, endTime time.Time, + userID, apiKeyID, accountID, groupID int64, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.ModelStat, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + UserID: userID, + APIKeyID: apiKeyID, + AccountID: accountID, + GroupID: groupID, + RequestType: requestType, + Stream: stream, + BillingType: billingType, + }) + entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + }) + if err != nil { + return nil, hit, err + } + stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload) + return stats, hit, err +} + +func (h *DashboardHandler) getGroupStatsCached( + ctx context.Context, + startTime, endTime time.Time, + userID, apiKeyID, accountID, groupID int64, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.GroupStat, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: 
endTime.UTC().Format(time.RFC3339), + UserID: userID, + APIKeyID: apiKeyID, + AccountID: accountID, + GroupID: groupID, + RequestType: requestType, + Stream: stream, + BillingType: billingType, + }) + entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + }) + if err != nil { + return nil, hit, err + } + stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload) + return stats, hit, err +} + +func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Granularity: granularity, + Limit: limit, + }) + entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit) + }) + if err != nil { + return nil, hit, err + } + trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload) + return trend, hit, err +} + +func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Granularity: granularity, + Limit: limit, + }) + entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit) + }) + if err != nil { + return nil, hit, err + } + trend, err := 
snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload) + return trend, hit, err +} diff --git a/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go new file mode 100644 index 00000000..16e10339 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go @@ -0,0 +1,302 @@ +package admin + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second) + +type dashboardSnapshotV2Stats struct { + usagestats.DashboardStats + Uptime int64 `json:"uptime"` +} + +type dashboardSnapshotV2Response struct { + GeneratedAt string `json:"generated_at"` + + StartDate string `json:"start_date"` + EndDate string `json:"end_date"` + Granularity string `json:"granularity"` + + Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"` + Trend []usagestats.TrendDataPoint `json:"trend,omitempty"` + Models []usagestats.ModelStat `json:"models,omitempty"` + Groups []usagestats.GroupStat `json:"groups,omitempty"` + UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"` +} + +type dashboardSnapshotV2Filters struct { + UserID int64 + APIKeyID int64 + AccountID int64 + GroupID int64 + Model string + RequestType *int16 + Stream *bool + BillingType *int8 +} + +type dashboardSnapshotV2CacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Granularity string `json:"granularity"` + UserID int64 `json:"user_id"` + APIKeyID int64 `json:"api_key_id"` + AccountID int64 `json:"account_id"` + GroupID int64 `json:"group_id"` + Model string `json:"model"` + RequestType *int16 `json:"request_type"` + Stream *bool `json:"stream"` + BillingType *int8 
`json:"billing_type"` + IncludeStats bool `json:"include_stats"` + IncludeTrend bool `json:"include_trend"` + IncludeModels bool `json:"include_models"` + IncludeGroups bool `json:"include_groups"` + IncludeUsersTrend bool `json:"include_users_trend"` + UsersTrendLimit int `json:"users_trend_limit"` +} + +func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) { + startTime, endTime := parseTimeRange(c) + granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day")) + if granularity != "hour" { + granularity = "day" + } + + includeStats := parseBoolQueryWithDefault(c.Query("include_stats"), true) + includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true) + includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true) + includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false) + includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false) + usersTrendLimit := 12 + if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" { + if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 { + usersTrendLimit = parsed + } + } + + filters, err := parseDashboardSnapshotV2Filters(c) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Granularity: granularity, + UserID: filters.UserID, + APIKeyID: filters.APIKeyID, + AccountID: filters.AccountID, + GroupID: filters.GroupID, + Model: filters.Model, + RequestType: filters.RequestType, + Stream: filters.Stream, + BillingType: filters.BillingType, + IncludeStats: includeStats, + IncludeTrend: includeTrend, + IncludeModels: includeModels, + IncludeGroups: includeGroups, + IncludeUsersTrend: includeUsersTrend, + UsersTrendLimit: usersTrendLimit, + }) + cacheKey := string(keyRaw) + + cached, hit, err := 
dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) { + return h.buildSnapshotV2Response( + c.Request.Context(), + startTime, + endTime, + granularity, + filters, + includeStats, + includeTrend, + includeModels, + includeGroups, + includeUsersTrend, + usersTrendLimit, + ) + }) + if err != nil { + response.Error(c, 500, err.Error()) + return + } + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) { + c.Status(http.StatusNotModified) + return + } + } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) + response.Success(c, cached.Payload) +} + +func (h *DashboardHandler) buildSnapshotV2Response( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + filters *dashboardSnapshotV2Filters, + includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool, + usersTrendLimit int, +) (*dashboardSnapshotV2Response, error) { + resp := &dashboardSnapshotV2Response{ + GeneratedAt: time.Now().UTC().Format(time.RFC3339), + StartDate: startTime.Format("2006-01-02"), + EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"), + Granularity: granularity, + } + + if includeStats { + stats, err := h.dashboardService.GetDashboardStats(ctx) + if err != nil { + return nil, errors.New("failed to get dashboard statistics") + } + resp.Stats = &dashboardSnapshotV2Stats{ + DashboardStats: *stats, + Uptime: int64(time.Since(h.startTime).Seconds()), + } + } + + if includeTrend { + trend, _, err := h.getUsageTrendCached( + ctx, + startTime, + endTime, + granularity, + filters.UserID, + filters.APIKeyID, + filters.AccountID, + filters.GroupID, + filters.Model, + filters.RequestType, + filters.Stream, + filters.BillingType, + ) + if err != nil { + return nil, errors.New("failed to get usage trend") + } + resp.Trend = trend + } + + if includeModels { + models, _, err := h.getModelStatsCached( + ctx, + startTime, + endTime, + 
filters.UserID, + filters.APIKeyID, + filters.AccountID, + filters.GroupID, + filters.RequestType, + filters.Stream, + filters.BillingType, + ) + if err != nil { + return nil, errors.New("failed to get model statistics") + } + resp.Models = models + } + + if includeGroups { + groups, _, err := h.getGroupStatsCached( + ctx, + startTime, + endTime, + filters.UserID, + filters.APIKeyID, + filters.AccountID, + filters.GroupID, + filters.RequestType, + filters.Stream, + filters.BillingType, + ) + if err != nil { + return nil, errors.New("failed to get group statistics") + } + resp.Groups = groups + } + + if includeUsersTrend { + usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit) + if err != nil { + return nil, errors.New("failed to get user usage trend") + } + resp.UsersTrend = usersTrend + } + + return resp, nil +} + +func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) { + filters := &dashboardSnapshotV2Filters{ + Model: strings.TrimSpace(c.Query("model")), + } + + if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" { + id, err := strconv.ParseInt(userIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.UserID = id + } + if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" { + id, err := strconv.ParseInt(apiKeyIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.APIKeyID = id + } + if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" { + id, err := strconv.ParseInt(accountIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.AccountID = id + } + if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != "" { + id, err := strconv.ParseInt(groupIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.GroupID = id + } + + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := 
service.ParseUsageRequestType(requestTypeStr) + if err != nil { + return nil, err + } + value := int16(parsed) + filters.RequestType = &value + } else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" { + streamVal, err := strconv.ParseBool(streamStr) + if err != nil { + return nil, err + } + filters.Stream = &streamVal + } + + if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" { + v, err := strconv.ParseInt(billingTypeStr, 10, 8) + if err != nil { + return nil, err + } + bt := int8(v) + filters.BillingType = &bt + } + + return filters, nil +} diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 1edf4dcc..4ffe64ee 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -1,6 +1,9 @@ package admin import ( + "bytes" + "encoding/json" + "fmt" "strconv" "strings" @@ -16,6 +19,55 @@ type GroupHandler struct { adminService service.AdminService } +type optionalLimitField struct { + set bool + value *float64 +} + +func (f *optionalLimitField) UnmarshalJSON(data []byte) error { + f.set = true + + trimmed := bytes.TrimSpace(data) + if bytes.Equal(trimmed, []byte("null")) { + f.value = nil + return nil + } + + var number float64 + if err := json.Unmarshal(trimmed, &number); err == nil { + f.value = &number + return nil + } + + var text string + if err := json.Unmarshal(trimmed, &text); err == nil { + text = strings.TrimSpace(text) + if text == "" { + f.value = nil + return nil + } + number, err = strconv.ParseFloat(text, 64) + if err != nil { + return fmt.Errorf("invalid numeric limit value %q: %w", text, err) + } + f.value = &number + return nil + } + + return fmt.Errorf("invalid limit value: %s", string(trimmed)) +} + +func (f optionalLimitField) ToServiceInput() *float64 { + if !f.set { + return nil + } + if f.value != nil { + return f.value + } + zero := 0.0 + return &zero +} + // 
NewGroupHandler creates a new admin group handler func NewGroupHandler(adminService service.AdminService) *GroupHandler { return &GroupHandler{ @@ -25,15 +77,15 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler { // CreateGroupRequest represents create group request type CreateGroupRequest struct { - Name string `json:"name" binding:"required"` - Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` - RateMultiplier float64 `json:"rate_multiplier"` - IsExclusive bool `json:"is_exclusive"` - SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` - DailyLimitUSD *float64 `json:"daily_limit_usd"` - WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` - MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` + Name string `json:"name" binding:"required"` + Description string `json:"description"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` + RateMultiplier float64 `json:"rate_multiplier"` + IsExclusive bool `json:"is_exclusive"` + SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` + DailyLimitUSD optionalLimitField `json:"daily_limit_usd"` + WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"` + MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` @@ -53,22 +105,25 @@ type CreateGroupRequest struct { SupportedModelScopes []string `json:"supported_model_scopes"` // Sora 存储配额 SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + DefaultMappedModel string `json:"default_mapped_model"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 
`json:"copy_accounts_from_group_ids"` } // UpdateGroupRequest represents update group request type UpdateGroupRequest struct { - Name string `json:"name"` - Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` - RateMultiplier *float64 `json:"rate_multiplier"` - IsExclusive *bool `json:"is_exclusive"` - Status string `json:"status" binding:"omitempty,oneof=active inactive"` - SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` - DailyLimitUSD *float64 `json:"daily_limit_usd"` - WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` - MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` + Name string `json:"name"` + Description string `json:"description"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` + RateMultiplier *float64 `json:"rate_multiplier"` + IsExclusive *bool `json:"is_exclusive"` + Status string `json:"status" binding:"omitempty,oneof=active inactive"` + SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` + DailyLimitUSD optionalLimitField `json:"daily_limit_usd"` + WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"` + MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` @@ -88,6 +143,9 @@ type UpdateGroupRequest struct { SupportedModelScopes *[]string `json:"supported_model_scopes"` // Sora 存储配额 SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` + DefaultMappedModel *string `json:"default_mapped_model"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -185,9 +243,9 @@ func (h 
*GroupHandler) Create(c *gin.Context) { RateMultiplier: req.RateMultiplier, IsExclusive: req.IsExclusive, SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, + DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(), + WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(), + MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(), ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, @@ -203,6 +261,8 @@ func (h *GroupHandler) Create(c *gin.Context) { MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, + AllowMessagesDispatch: req.AllowMessagesDispatch, + DefaultMappedModel: req.DefaultMappedModel, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -236,9 +296,9 @@ func (h *GroupHandler) Update(c *gin.Context) { IsExclusive: req.IsExclusive, Status: req.Status, SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, + DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(), + WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(), + MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(), ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, @@ -254,6 +314,8 @@ func (h *GroupHandler) Update(c *gin.Context) { MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, + AllowMessagesDispatch: req.AllowMessagesDispatch, + DefaultMappedModel: req.DefaultMappedModel, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -325,6 +387,72 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) { response.Paginated(c, outKeys, total, page, pageSize) } +// GetGroupRateMultipliers handles getting rate multipliers for 
users in a group +// GET /api/v1/admin/groups/:id/rate-multipliers +func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) { + groupID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid group ID") + return + } + + entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + if entries == nil { + entries = []service.UserGroupRateEntry{} + } + response.Success(c, entries) +} + +// ClearGroupRateMultipliers handles clearing all rate multipliers for a group +// DELETE /api/v1/admin/groups/:id/rate-multipliers +func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) { + groupID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid group ID") + return + } + + if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"}) +} + +// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request +type BatchSetGroupRateMultipliersRequest struct { + Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"` +} + +// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group +// PUT /api/v1/admin/groups/:id/rate-multipliers +func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) { + groupID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid group ID") + return + } + + var req BatchSetGroupRateMultipliersRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, 
gin.H{"message": "Rate multipliers updated successfully"}) +} + // UpdateSortOrderRequest represents the request to update group sort orders type UpdateSortOrderRequest struct { Updates []struct { diff --git a/backend/internal/handler/admin/id_list_utils.go b/backend/internal/handler/admin/id_list_utils.go new file mode 100644 index 00000000..2aeefe38 --- /dev/null +++ b/backend/internal/handler/admin/id_list_utils.go @@ -0,0 +1,25 @@ +package admin + +import "sort" + +func normalizeInt64IDList(ids []int64) []int64 { + if len(ids) == 0 { + return nil + } + + out := make([]int64, 0, len(ids)) + seen := make(map[int64]struct{}, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + + sort.Slice(out, func(i, j int) bool { return out[i] < out[j] }) + return out +} diff --git a/backend/internal/handler/admin/id_list_utils_test.go b/backend/internal/handler/admin/id_list_utils_test.go new file mode 100644 index 00000000..aa65d5c0 --- /dev/null +++ b/backend/internal/handler/admin/id_list_utils_test.go @@ -0,0 +1,57 @@ +//go:build unit + +package admin + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNormalizeInt64IDList(t *testing.T) { + tests := []struct { + name string + in []int64 + want []int64 + }{ + {"nil input", nil, nil}, + {"empty input", []int64{}, nil}, + {"single element", []int64{5}, []int64{5}}, + {"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}}, + {"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}}, + {"zero filtered", []int64{0, 1, 2}, []int64{1, 2}}, + {"negative filtered", []int64{-5, -1, 3}, []int64{3}}, + {"all invalid", []int64{0, -1, -2}, []int64{}}, + {"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := normalizeInt64IDList(tc.in) + if tc.want == nil { + require.Nil(t, got) + } else { + 
require.Equal(t, tc.want, got) + } + }) + } +} + +func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) { + tests := []struct { + name string + ids []int64 + want string + }{ + {"empty", nil, "accounts_today_stats_empty"}, + {"single", []int64{42}, "accounts_today_stats:42"}, + {"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := buildAccountTodayStatsBatchCacheKey(tc.ids) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go index 5d354fd3..4e6179db 100644 --- a/backend/internal/handler/admin/openai_oauth_handler.go +++ b/backend/internal/handler/admin/openai_oauth_handler.go @@ -289,6 +289,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { Platform: platform, Type: "oauth", Credentials: credentials, + Extra: nil, ProxyID: req.ProxyID, Concurrency: req.Concurrency, Priority: req.Priority, diff --git a/backend/internal/handler/admin/ops_alerts_handler.go b/backend/internal/handler/admin/ops_alerts_handler.go index c9da19c7..edc8c7f7 100644 --- a/backend/internal/handler/admin/ops_alerts_handler.go +++ b/backend/internal/handler/admin/ops_alerts_handler.go @@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{ "cpu_usage_percent", "memory_usage_percent", "concurrency_queue_depth", + "group_available_accounts", + "group_available_ratio", + "group_rate_limit_ratio", + "account_rate_limited_count", + "account_error_count", + "account_error_ratio", + "overload_account_count", } var validOpsAlertMetricTypeSet = func() map[string]struct{} { @@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool { "error_rate", "upstream_error_rate", "cpu_usage_percent", - "memory_usage_percent": + "memory_usage_percent", + "group_available_ratio", + "group_rate_limit_ratio", + "account_error_ratio": return true default: return false diff 
--git a/backend/internal/handler/admin/ops_snapshot_v2_handler.go b/backend/internal/handler/admin/ops_snapshot_v2_handler.go new file mode 100644 index 00000000..5cac00fe --- /dev/null +++ b/backend/internal/handler/admin/ops_snapshot_v2_handler.go @@ -0,0 +1,145 @@ +package admin + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "golang.org/x/sync/errgroup" +) + +var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second) + +type opsDashboardSnapshotV2Response struct { + GeneratedAt string `json:"generated_at"` + + Overview *service.OpsDashboardOverview `json:"overview"` + ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"` + ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"` +} + +type opsDashboardSnapshotV2CacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Platform string `json:"platform"` + GroupID *int64 `json:"group_id"` + QueryMode service.OpsQueryMode `json:"mode"` + BucketSecond int `json:"bucket_second"` +} + +// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request. 
+// GET /api/v1/admin/ops/dashboard/snapshot-v2 +func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + startTime, endTime, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsDashboardFilter{ + StartTime: startTime, + EndTime: endTime, + Platform: strings.TrimSpace(c.Query("platform")), + QueryMode: parseOpsQueryMode(c), + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &id + } + bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime)) + + keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Platform: filter.Platform, + GroupID: filter.GroupID, + QueryMode: filter.QueryMode, + BucketSecond: bucketSeconds, + }) + cacheKey := string(keyRaw) + + if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok { + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) { + c.Status(http.StatusNotModified) + return + } + } + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + var ( + overview *service.OpsDashboardOverview + trend *service.OpsThroughputTrendResponse + errTrend *service.OpsErrorTrendResponse + ) + g, gctx := errgroup.WithContext(c.Request.Context()) + g.Go(func() error { + f := *filter + result, err := h.opsService.GetDashboardOverview(gctx, &f) + if err != nil { + return err + } + overview = result + return nil + }) 
+ g.Go(func() error { + f := *filter + result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds) + if err != nil { + return err + } + trend = result + return nil + }) + g.Go(func() error { + f := *filter + result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds) + if err != nil { + return err + } + errTrend = result + return nil + }) + if err := g.Wait(); err != nil { + response.ErrorFrom(c, err) + return + } + + resp := &opsDashboardSnapshotV2Response{ + GeneratedAt: time.Now().UTC().Format(time.RFC3339), + Overview: overview, + ThroughputTrend: trend, + ErrorTrend: errTrend, + } + + cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp) + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + } + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, resp) +} diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 0a932ee9..13ea88d9 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct { } // CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user. 
+// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。 type CreateAndRedeemCodeRequest struct { - Code string `json:"code" binding:"required,min=3,max=128"` - Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"` - Value float64 `json:"value" binding:"required,gt=0"` - UserID int64 `json:"user_id" binding:"required,gt=0"` - Notes string `json:"notes"` + Code string `json:"code" binding:"required,min=3,max=128"` + Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容) + Value float64 `json:"value" binding:"required,gt=0"` + UserID int64 `json:"user_id" binding:"required,gt=0"` + GroupID *int64 `json:"group_id"` // subscription 类型必填 + ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0 + Notes string `json:"notes"` } // List handles listing all redeem codes with pagination @@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) { return } req.Code = strings.TrimSpace(req.Code) + // 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。 + // 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。 + if req.Type == "" { + req.Type = "balance" + } + + if req.Type == "subscription" { + if req.GroupID == nil { + response.BadRequest(c, "group_id is required for subscription type") + return + } + if req.ValidityDays <= 0 { + response.BadRequest(c, "validity_days must be greater than 0 for subscription type") + return + } + } executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { existing, err := h.redeemService.GetByCode(ctx, req.Code) @@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) { } createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{ - Code: req.Code, - Type: req.Type, - Value: req.Value, - Status: service.StatusUnused, - Notes: req.Notes, + Code: req.Code, + Type: req.Type, + 
Value: req.Value, + Status: service.StatusUnused, + Notes: req.Notes, + GroupID: req.GroupID, + ValidityDays: req.ValidityDays, }) if createErr != nil { // Unique code race: if code now exists, use idempotent semantics by used_by. diff --git a/backend/internal/handler/admin/redeem_handler_test.go b/backend/internal/handler/admin/redeem_handler_test.go new file mode 100644 index 00000000..0d42f64f --- /dev/null +++ b/backend/internal/handler/admin/redeem_handler_test.go @@ -0,0 +1,135 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal) +// RedeemService so that CreateAndRedeem's nil guard passes and we can test the +// parameter-validation layer that runs before any service call. +func newCreateAndRedeemHandler() *RedeemHandler { + return &RedeemHandler{ + adminService: newStubAdminService(), + redeemService: &service.RedeemService{}, // non-nil to pass nil guard + } +} + +// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response +// status code. For cases that pass validation and proceed into the service layer, +// a panic may occur (because RedeemService internals are nil); this is expected +// and treated as "validation passed" (returns 0 to indicate panic). 
+func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + jsonBytes, err := json.Marshal(body) + require.NoError(t, err) + c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes)) + c.Request.Header.Set("Content-Type", "application/json") + + defer func() { + if r := recover(); r != nil { + // Panic means we passed validation and entered service layer (expected for minimal stub). + code = 0 + } + }() + handler.CreateAndRedeem(c) + return w.Code +} + +func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) { + // 不传 type 字段时应默认 balance,不触发 subscription 校验。 + // 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。 + h := newCreateAndRedeemHandler() + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-balance-default", + "value": 10.0, + "user_id": 1, + }) + + assert.NotEqual(t, http.StatusBadRequest, code, + "omitting type should default to balance and pass validation") +} + +func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) { + h := newCreateAndRedeemHandler() + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-sub-no-group", + "type": "subscription", + "value": 29.9, + "user_id": 1, + "validity_days": 30, + // group_id 缺失 + }) + + assert.Equal(t, http.StatusBadRequest, code) +} + +func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) { + groupID := int64(5) + h := newCreateAndRedeemHandler() + + cases := []struct { + name string + validityDays int + }{ + {"zero", 0}, + {"negative", -1}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-sub-bad-days-" + tc.name, + "type": "subscription", + "value": 29.9, + "user_id": 1, + "group_id": groupID, + "validity_days": 
tc.validityDays, + }) + + assert.Equal(t, http.StatusBadRequest, code) + }) + } +} + +func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) { + groupID := int64(5) + h := newCreateAndRedeemHandler() + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-sub-valid", + "type": "subscription", + "value": 29.9, + "user_id": 1, + "group_id": groupID, + "validity_days": 31, + }) + + assert.NotEqual(t, http.StatusBadRequest, code, + "valid subscription params should pass validation") +} + +func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) { + h := newCreateAndRedeemHandler() + // balance 类型不传 group_id 和 validity_days,不应报 400 + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-balance-no-extras", + "type": "balance", + "value": 50.0, + "user_id": 1, + }) + + assert.NotEqual(t, http.StatusBadRequest, code, + "balance type should not require group_id or validity_days") +} diff --git a/backend/internal/handler/admin/scheduled_test_handler.go b/backend/internal/handler/admin/scheduled_test_handler.go new file mode 100644 index 00000000..d9f39737 --- /dev/null +++ b/backend/internal/handler/admin/scheduled_test_handler.go @@ -0,0 +1,163 @@ +package admin + +import ( + "net/http" + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// ScheduledTestHandler handles admin scheduled-test-plan management. +type ScheduledTestHandler struct { + scheduledTestSvc *service.ScheduledTestService +} + +// NewScheduledTestHandler creates a new ScheduledTestHandler. 
+func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler { + return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc} +} + +type createScheduledTestPlanRequest struct { + AccountID int64 `json:"account_id" binding:"required"` + ModelID string `json:"model_id"` + CronExpression string `json:"cron_expression" binding:"required"` + Enabled *bool `json:"enabled"` + MaxResults int `json:"max_results"` + AutoRecover *bool `json:"auto_recover"` +} + +type updateScheduledTestPlanRequest struct { + ModelID string `json:"model_id"` + CronExpression string `json:"cron_expression"` + Enabled *bool `json:"enabled"` + MaxResults int `json:"max_results"` + AutoRecover *bool `json:"auto_recover"` +} + +// ListByAccount GET /admin/accounts/:id/scheduled-test-plans +func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "invalid account id") + return + } + + plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID) + if err != nil { + response.InternalError(c, err.Error()) + return + } + c.JSON(http.StatusOK, plans) +} + +// Create POST /admin/scheduled-test-plans +func (h *ScheduledTestHandler) Create(c *gin.Context) { + var req createScheduledTestPlanRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + plan := &service.ScheduledTestPlan{ + AccountID: req.AccountID, + ModelID: req.ModelID, + CronExpression: req.CronExpression, + Enabled: true, + MaxResults: req.MaxResults, + } + if req.Enabled != nil { + plan.Enabled = *req.Enabled + } + if req.AutoRecover != nil { + plan.AutoRecover = *req.AutoRecover + } + + created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + c.JSON(http.StatusOK, created) +} + +// Update PUT /admin/scheduled-test-plans/:id +func 
(h *ScheduledTestHandler) Update(c *gin.Context) { + planID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "invalid plan id") + return + } + + existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID) + if err != nil { + response.NotFound(c, "plan not found") + return + } + + var req updateScheduledTestPlanRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if req.ModelID != "" { + existing.ModelID = req.ModelID + } + if req.CronExpression != "" { + existing.CronExpression = req.CronExpression + } + if req.Enabled != nil { + existing.Enabled = *req.Enabled + } + if req.MaxResults > 0 { + existing.MaxResults = req.MaxResults + } + if req.AutoRecover != nil { + existing.AutoRecover = *req.AutoRecover + } + + updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + c.JSON(http.StatusOK, updated) +} + +// Delete DELETE /admin/scheduled-test-plans/:id +func (h *ScheduledTestHandler) Delete(c *gin.Context) { + planID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "invalid plan id") + return + } + + if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil { + response.InternalError(c, err.Error()) + return + } + c.JSON(http.StatusOK, gin.H{"message": "deleted"}) +} + +// ListResults GET /admin/scheduled-test-plans/:id/results +func (h *ScheduledTestHandler) ListResults(c *gin.Context) { + planID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "invalid plan id") + return + } + + limit := 50 + if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 { + limit = l + } + + results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit) + if err != nil { + response.InternalError(c, err.Error()) + return + } + c.JSON(http.StatusOK, 
results) +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index e7da042c..c966cb7d 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1,6 +1,9 @@ package admin import ( + "crypto/rand" + "encoding/hex" + "encoding/json" "fmt" "log" "net/http" @@ -20,6 +23,18 @@ import ( // semverPattern 预编译 semver 格式校验正则 var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`) +// menuItemIDPattern validates custom menu item IDs: alphanumeric, hyphens, underscores only. +var menuItemIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + +// generateMenuItemID generates a short random hex ID for a custom menu item. +func generateMenuItemID() (string, error) { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate menu item ID: %w", err) + } + return hex.EncodeToString(b), nil +} + // SettingHandler 系统设置处理器 type SettingHandler struct { settingService *service.SettingService @@ -62,8 +77,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { response.Success(c, dto.SystemSettings{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, PromoCodeEnabled: settings.PromoCodeEnabled, PasswordResetEnabled: settings.PasswordResetEnabled, + FrontendURL: settings.FrontendURL, InvitationCodeEnabled: settings.InvitationCodeEnabled, TotpEnabled: settings.TotpEnabled, TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), @@ -92,6 +109,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, + CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), DefaultConcurrency: 
settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, DefaultSubscriptions: defaultSubscriptions, @@ -107,18 +125,22 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { OpsQueryModeDefault: settings.OpsQueryModeDefault, OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, MinClaudeCodeVersion: settings.MinClaudeCodeVersion, + AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling, + BackendModeEnabled: settings.BackendModeEnabled, }) } // UpdateSettingsRequest 更新设置请求 type UpdateSettingsRequest struct { // 注册设置 - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + FrontendURL string `json:"frontend_url"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 // 邮件服务设置 SMTPHost string `json:"smtp_host"` @@ -141,17 +163,18 @@ type UpdateSettingsRequest struct { LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` // OEM设置 - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL 
*string `json:"purchase_subscription_url"` - SoraClientEnabled bool `json:"sora_client_enabled"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` // 默认配置 DefaultConcurrency int `json:"default_concurrency"` @@ -176,6 +199,12 @@ type UpdateSettingsRequest struct { OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"` MinClaudeCodeVersion string `json:"min_claude_code_version"` + + // 分组隔离 + AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` + + // Backend Mode + BackendModeEnabled bool `json:"backend_mode_enabled"` } // UpdateSettings 更新系统设置 @@ -299,6 +328,93 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // Frontend URL 验证 + req.FrontendURL = strings.TrimSpace(req.FrontendURL) + if req.FrontendURL != "" { + if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil { + response.BadRequest(c, "Frontend URL must be an absolute http(s) URL") + return + } + } + + // 自定义菜单项验证 + const ( + maxCustomMenuItems = 20 + maxMenuItemLabelLen = 50 + maxMenuItemURLLen = 2048 + maxMenuItemIconSVGLen = 10 * 1024 // 10KB + maxMenuItemIDLen = 32 + ) + + customMenuJSON := previousSettings.CustomMenuItems + if req.CustomMenuItems != nil { + items := *req.CustomMenuItems + if len(items) > maxCustomMenuItems { + response.BadRequest(c, "Too many custom menu items (max 20)") + return + } + for i, item := range items { + if strings.TrimSpace(item.Label) == "" { + 
response.BadRequest(c, "Custom menu item label is required") + return + } + if len(item.Label) > maxMenuItemLabelLen { + response.BadRequest(c, "Custom menu item label is too long (max 50 characters)") + return + } + if strings.TrimSpace(item.URL) == "" { + response.BadRequest(c, "Custom menu item URL is required") + return + } + if len(item.URL) > maxMenuItemURLLen { + response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)") + return + } + if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil { + response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL") + return + } + if item.Visibility != "user" && item.Visibility != "admin" { + response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'") + return + } + if len(item.IconSVG) > maxMenuItemIconSVGLen { + response.BadRequest(c, "Custom menu item icon SVG is too large (max 10KB)") + return + } + // Auto-generate ID if missing + if strings.TrimSpace(item.ID) == "" { + id, err := generateMenuItemID() + if err != nil { + response.Error(c, http.StatusInternalServerError, "Failed to generate menu item ID") + return + } + items[i].ID = id + } else if len(item.ID) > maxMenuItemIDLen { + response.BadRequest(c, "Custom menu item ID is too long (max 32 characters)") + return + } else if !menuItemIDPattern.MatchString(item.ID) { + response.BadRequest(c, "Custom menu item ID contains invalid characters (only a-z, A-Z, 0-9, - and _ are allowed)") + return + } + } + // ID uniqueness check + seen := make(map[string]struct{}, len(items)) + for _, item := range items { + if _, exists := seen[item.ID]; exists { + response.BadRequest(c, "Duplicate custom menu item ID: "+item.ID) + return + } + seen[item.ID] = struct{}{} + } + menuBytes, err := json.Marshal(items) + if err != nil { + response.BadRequest(c, "Failed to serialize custom menu items") + return + } + customMenuJSON = string(menuBytes) + } + // Ops metrics collector interval validation 
(seconds). if req.OpsMetricsIntervalSeconds != nil { v := *req.OpsMetricsIntervalSeconds @@ -327,48 +443,53 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } settings := &service.SystemSettings{ - RegistrationEnabled: req.RegistrationEnabled, - EmailVerifyEnabled: req.EmailVerifyEnabled, - PromoCodeEnabled: req.PromoCodeEnabled, - PasswordResetEnabled: req.PasswordResetEnabled, - InvitationCodeEnabled: req.InvitationCodeEnabled, - TotpEnabled: req.TotpEnabled, - SMTPHost: req.SMTPHost, - SMTPPort: req.SMTPPort, - SMTPUsername: req.SMTPUsername, - SMTPPassword: req.SMTPPassword, - SMTPFrom: req.SMTPFrom, - SMTPFromName: req.SMTPFromName, - SMTPUseTLS: req.SMTPUseTLS, - TurnstileEnabled: req.TurnstileEnabled, - TurnstileSiteKey: req.TurnstileSiteKey, - TurnstileSecretKey: req.TurnstileSecretKey, - LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, - LinuxDoConnectClientID: req.LinuxDoConnectClientID, - LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, - LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, - SiteName: req.SiteName, - SiteLogo: req.SiteLogo, - SiteSubtitle: req.SiteSubtitle, - APIBaseURL: req.APIBaseURL, - ContactInfo: req.ContactInfo, - DocURL: req.DocURL, - HomeContent: req.HomeContent, - HideCcsImportButton: req.HideCcsImportButton, - PurchaseSubscriptionEnabled: purchaseEnabled, - PurchaseSubscriptionURL: purchaseURL, - SoraClientEnabled: req.SoraClientEnabled, - DefaultConcurrency: req.DefaultConcurrency, - DefaultBalance: req.DefaultBalance, - DefaultSubscriptions: defaultSubscriptions, - EnableModelFallback: req.EnableModelFallback, - FallbackModelAnthropic: req.FallbackModelAnthropic, - FallbackModelOpenAI: req.FallbackModelOpenAI, - FallbackModelGemini: req.FallbackModelGemini, - FallbackModelAntigravity: req.FallbackModelAntigravity, - EnableIdentityPatch: req.EnableIdentityPatch, - IdentityPatchPrompt: req.IdentityPatchPrompt, - MinClaudeCodeVersion: req.MinClaudeCodeVersion, + RegistrationEnabled: 
req.RegistrationEnabled, + EmailVerifyEnabled: req.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: req.PromoCodeEnabled, + PasswordResetEnabled: req.PasswordResetEnabled, + FrontendURL: req.FrontendURL, + InvitationCodeEnabled: req.InvitationCodeEnabled, + TotpEnabled: req.TotpEnabled, + SMTPHost: req.SMTPHost, + SMTPPort: req.SMTPPort, + SMTPUsername: req.SMTPUsername, + SMTPPassword: req.SMTPPassword, + SMTPFrom: req.SMTPFrom, + SMTPFromName: req.SMTPFromName, + SMTPUseTLS: req.SMTPUseTLS, + TurnstileEnabled: req.TurnstileEnabled, + TurnstileSiteKey: req.TurnstileSiteKey, + TurnstileSecretKey: req.TurnstileSecretKey, + LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, + LinuxDoConnectClientID: req.LinuxDoConnectClientID, + LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, + LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + SiteName: req.SiteName, + SiteLogo: req.SiteLogo, + SiteSubtitle: req.SiteSubtitle, + APIBaseURL: req.APIBaseURL, + ContactInfo: req.ContactInfo, + DocURL: req.DocURL, + HomeContent: req.HomeContent, + HideCcsImportButton: req.HideCcsImportButton, + PurchaseSubscriptionEnabled: purchaseEnabled, + PurchaseSubscriptionURL: purchaseURL, + SoraClientEnabled: req.SoraClientEnabled, + CustomMenuItems: customMenuJSON, + DefaultConcurrency: req.DefaultConcurrency, + DefaultBalance: req.DefaultBalance, + DefaultSubscriptions: defaultSubscriptions, + EnableModelFallback: req.EnableModelFallback, + FallbackModelAnthropic: req.FallbackModelAnthropic, + FallbackModelOpenAI: req.FallbackModelOpenAI, + FallbackModelGemini: req.FallbackModelGemini, + FallbackModelAntigravity: req.FallbackModelAntigravity, + EnableIdentityPatch: req.EnableIdentityPatch, + IdentityPatchPrompt: req.IdentityPatchPrompt, + MinClaudeCodeVersion: req.MinClaudeCodeVersion, + AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling, + BackendModeEnabled: req.BackendModeEnabled, 
OpsMonitoringEnabled: func() bool { if req.OpsMonitoringEnabled != nil { return *req.OpsMonitoringEnabled @@ -419,8 +540,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.Success(c, dto.SystemSettings{ RegistrationEnabled: updatedSettings.RegistrationEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist, PromoCodeEnabled: updatedSettings.PromoCodeEnabled, PasswordResetEnabled: updatedSettings.PasswordResetEnabled, + FrontendURL: updatedSettings.FrontendURL, InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled, TotpEnabled: updatedSettings.TotpEnabled, TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), @@ -449,6 +572,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, SoraClientEnabled: updatedSettings.SoraClientEnabled, + CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, DefaultSubscriptions: updatedDefaultSubscriptions, @@ -464,6 +588,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, + AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling, + BackendModeEnabled: updatedSettings.BackendModeEnabled, }) } @@ -495,9 +621,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EmailVerifyEnabled != after.EmailVerifyEnabled { changed = append(changed, "email_verify_enabled") } + if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) { + 
changed = append(changed, "registration_email_suffix_whitelist") + } if before.PasswordResetEnabled != after.PasswordResetEnabled { changed = append(changed, "password_reset_enabled") } + if before.FrontendURL != after.FrontendURL { + changed = append(changed, "frontend_url") + } if before.TotpEnabled != after.TotpEnabled { changed = append(changed, "totp_enabled") } @@ -612,6 +744,21 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion { changed = append(changed, "min_claude_code_version") } + if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling { + changed = append(changed, "allow_ungrouped_key_scheduling") + } + if before.BackendModeEnabled != after.BackendModeEnabled { + changed = append(changed, "backend_mode_enabled") + } + if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled { + changed = append(changed, "purchase_subscription_enabled") + } + if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL { + changed = append(changed, "purchase_subscription_url") + } + if before.CustomMenuItems != after.CustomMenuItems { + changed = append(changed, "custom_menu_items") + } return changed } @@ -632,6 +779,18 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto return normalized } +func equalStringSlice(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { if len(a) != len(b) { return false @@ -685,7 +844,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) { err := h.emailService.TestSMTPConnectionWithConfig(config) if err != nil { - response.ErrorFrom(c, err) + response.BadRequest(c, "SMTP connection test failed: "+err.Error()) return } @@ -771,7 +930,7 @@ func (h *SettingHandler) SendTestEmail(c 
*gin.Context) { ` if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil { - response.ErrorFrom(c, err) + response.BadRequest(c, "Failed to send test email: "+err.Error()) return } @@ -1214,6 +1373,118 @@ func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) { response.Success(c, gin.H{"message": "S3 连接成功"}) } +// GetRectifierSettings 获取请求整流器配置 +// GET /api/v1/admin/settings/rectifier +func (h *SettingHandler) GetRectifierSettings(c *gin.Context) { + settings, err := h.settingService.GetRectifierSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.RectifierSettings{ + Enabled: settings.Enabled, + ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled, + ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled, + }) +} + +// UpdateRectifierSettingsRequest 更新整流器配置请求 +type UpdateRectifierSettingsRequest struct { + Enabled bool `json:"enabled"` + ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` + ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` +} + +// UpdateRectifierSettings 更新请求整流器配置 +// PUT /api/v1/admin/settings/rectifier +func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) { + var req UpdateRectifierSettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + settings := &service.RectifierSettings{ + Enabled: req.Enabled, + ThinkingSignatureEnabled: req.ThinkingSignatureEnabled, + ThinkingBudgetEnabled: req.ThinkingBudgetEnabled, + } + + if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 重新获取设置返回 + updatedSettings, err := h.settingService.GetRectifierSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.RectifierSettings{ + Enabled: updatedSettings.Enabled, + 
ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled, + ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled, + }) +} + +// GetBetaPolicySettings 获取 Beta 策略配置 +// GET /api/v1/admin/settings/beta-policy +func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) { + settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + rules := make([]dto.BetaPolicyRule, len(settings.Rules)) + for i, r := range settings.Rules { + rules[i] = dto.BetaPolicyRule(r) + } + response.Success(c, dto.BetaPolicySettings{Rules: rules}) +} + +// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求 +type UpdateBetaPolicySettingsRequest struct { + Rules []dto.BetaPolicyRule `json:"rules"` +} + +// UpdateBetaPolicySettings 更新 Beta 策略配置 +// PUT /api/v1/admin/settings/beta-policy +func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) { + var req UpdateBetaPolicySettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + rules := make([]service.BetaPolicyRule, len(req.Rules)) + for i, r := range req.Rules { + rules[i] = service.BetaPolicyRule(r) + } + + settings := &service.BetaPolicySettings{Rules: rules} + if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // Re-fetch to return updated settings + updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + outRules := make([]dto.BetaPolicyRule, len(updated.Rules)) + for i, r := range updated.Rules { + outRules[i] = dto.BetaPolicyRule(r) + } + response.Success(c, dto.BetaPolicySettings{Rules: outRules}) +} + // UpdateStreamTimeoutSettingsRequest 更新流超时配置请求 type UpdateStreamTimeoutSettingsRequest struct { Enabled bool `json:"enabled"` diff --git 
a/backend/internal/handler/admin/snapshot_cache.go b/backend/internal/handler/admin/snapshot_cache.go new file mode 100644 index 00000000..d6973ff9 --- /dev/null +++ b/backend/internal/handler/admin/snapshot_cache.go @@ -0,0 +1,138 @@ +package admin + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "strings" + "sync" + "time" + + "golang.org/x/sync/singleflight" +) + +type snapshotCacheEntry struct { + ETag string + Payload any + ExpiresAt time.Time +} + +type snapshotCache struct { + mu sync.RWMutex + ttl time.Duration + items map[string]snapshotCacheEntry + sf singleflight.Group +} + +type snapshotCacheLoadResult struct { + Entry snapshotCacheEntry + Hit bool +} + +func newSnapshotCache(ttl time.Duration) *snapshotCache { + if ttl <= 0 { + ttl = 30 * time.Second + } + return &snapshotCache{ + ttl: ttl, + items: make(map[string]snapshotCacheEntry), + } +} + +func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) { + if c == nil || key == "" { + return snapshotCacheEntry{}, false + } + now := time.Now() + + c.mu.RLock() + entry, ok := c.items[key] + c.mu.RUnlock() + if !ok { + return snapshotCacheEntry{}, false + } + if now.After(entry.ExpiresAt) { + c.mu.Lock() + delete(c.items, key) + c.mu.Unlock() + return snapshotCacheEntry{}, false + } + return entry, true +} + +func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry { + if c == nil { + return snapshotCacheEntry{} + } + entry := snapshotCacheEntry{ + ETag: buildETagFromAny(payload), + Payload: payload, + ExpiresAt: time.Now().Add(c.ttl), + } + if key == "" { + return entry + } + c.mu.Lock() + c.items[key] = entry + c.mu.Unlock() + return entry +} + +func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) { + if load == nil { + return snapshotCacheEntry{}, false, nil + } + if entry, ok := c.Get(key); ok { + return entry, true, nil + } + if c == nil || key == "" { + payload, err := load() + if err != nil { + return 
snapshotCacheEntry{}, false, err + } + return c.Set(key, payload), false, nil + } + + value, err, _ := c.sf.Do(key, func() (any, error) { + if entry, ok := c.Get(key); ok { + return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil + } + payload, err := load() + if err != nil { + return nil, err + } + return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil + }) + if err != nil { + return snapshotCacheEntry{}, false, err + } + result, ok := value.(snapshotCacheLoadResult) + if !ok { + return snapshotCacheEntry{}, false, nil + } + return result.Entry, result.Hit, nil +} + +func buildETagFromAny(payload any) string { + raw, err := json.Marshal(payload) + if err != nil { + return "" + } + sum := sha256.Sum256(raw) + return "\"" + hex.EncodeToString(sum[:]) + "\"" +} + +func parseBoolQueryWithDefault(raw string, def bool) bool { + value := strings.TrimSpace(strings.ToLower(raw)) + if value == "" { + return def + } + switch value { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + default: + return def + } +} diff --git a/backend/internal/handler/admin/snapshot_cache_test.go b/backend/internal/handler/admin/snapshot_cache_test.go new file mode 100644 index 00000000..ee3f72ca --- /dev/null +++ b/backend/internal/handler/admin/snapshot_cache_test.go @@ -0,0 +1,185 @@ +//go:build unit + +package admin + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSnapshotCache_SetAndGet(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + + entry := c.Set("key1", map[string]string{"hello": "world"}) + require.NotEmpty(t, entry.ETag) + require.NotNil(t, entry.Payload) + + got, ok := c.Get("key1") + require.True(t, ok) + require.Equal(t, entry.ETag, got.ETag) +} + +func TestSnapshotCache_Expiration(t *testing.T) { + c := newSnapshotCache(1 * time.Millisecond) + + c.Set("key1", "value") + time.Sleep(5 * time.Millisecond) + + _, ok := c.Get("key1") + 
require.False(t, ok, "expired entry should not be returned") +} + +func TestSnapshotCache_GetEmptyKey(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + _, ok := c.Get("") + require.False(t, ok) +} + +func TestSnapshotCache_GetMiss(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + _, ok := c.Get("nonexistent") + require.False(t, ok) +} + +func TestSnapshotCache_NilReceiver(t *testing.T) { + var c *snapshotCache + _, ok := c.Get("key") + require.False(t, ok) + + entry := c.Set("key", "value") + require.Empty(t, entry.ETag) +} + +func TestSnapshotCache_SetEmptyKey(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + + // Set with empty key should return entry but not store it + entry := c.Set("", "value") + require.NotEmpty(t, entry.ETag) + + _, ok := c.Get("") + require.False(t, ok) +} + +func TestSnapshotCache_DefaultTTL(t *testing.T) { + c := newSnapshotCache(0) + require.Equal(t, 30*time.Second, c.ttl) + + c2 := newSnapshotCache(-1 * time.Second) + require.Equal(t, 30*time.Second, c2.ttl) +} + +func TestSnapshotCache_ETagDeterministic(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + payload := map[string]int{"a": 1, "b": 2} + + entry1 := c.Set("k1", payload) + entry2 := c.Set("k2", payload) + require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag") +} + +func TestSnapshotCache_ETagFormat(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + entry := c.Set("k", "test") + // ETag should be quoted hex string: "abcdef..." 
+ require.True(t, len(entry.ETag) > 2) + require.Equal(t, byte('"'), entry.ETag[0]) + require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1]) +} + +func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) { + // channels are not JSON-serializable + etag := buildETagFromAny(make(chan int)) + require.Empty(t, etag) +} + +func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + var loads atomic.Int32 + + entry, hit, err := c.GetOrLoad("key1", func() (any, error) { + loads.Add(1) + return map[string]string{"hello": "world"}, nil + }) + require.NoError(t, err) + require.False(t, hit) + require.NotEmpty(t, entry.ETag) + require.Equal(t, int32(1), loads.Load()) + + entry2, hit, err := c.GetOrLoad("key1", func() (any, error) { + loads.Add(1) + return map[string]string{"unexpected": "value"}, nil + }) + require.NoError(t, err) + require.True(t, hit) + require.Equal(t, entry.ETag, entry2.ETag) + require.Equal(t, int32(1), loads.Load()) +} + +func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + var loads atomic.Int32 + start := make(chan struct{}) + const callers = 8 + errCh := make(chan error, callers) + + var wg sync.WaitGroup + wg.Add(callers) + for range callers { + go func() { + defer wg.Done() + <-start + _, _, err := c.GetOrLoad("shared", func() (any, error) { + loads.Add(1) + time.Sleep(20 * time.Millisecond) + return "value", nil + }) + errCh <- err + }() + } + close(start) + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } + + require.Equal(t, int32(1), loads.Load()) +} + +func TestParseBoolQueryWithDefault(t *testing.T) { + tests := []struct { + name string + raw string + def bool + want bool + }{ + {"empty returns default true", "", true, true}, + {"empty returns default false", "", false, false}, + {"1", "1", false, true}, + {"true", "true", false, true}, + {"TRUE", "TRUE", false, true}, + {"yes", "yes", false, 
true}, + {"on", "on", false, true}, + {"0", "0", true, false}, + {"false", "false", true, false}, + {"FALSE", "FALSE", true, false}, + {"no", "no", true, false}, + {"off", "off", true, false}, + {"whitespace trimmed", " true ", false, true}, + {"unknown returns default true", "maybe", true, true}, + {"unknown returns default false", "maybe", false, false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := parseBoolQueryWithDefault(tc.raw, tc.def) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index e5b6db13..342964b6 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -216,6 +216,38 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { }) } +// ResetSubscriptionQuotaRequest represents the reset quota request +type ResetSubscriptionQuotaRequest struct { + Daily bool `json:"daily"` + Weekly bool `json:"weekly"` + Monthly bool `json:"monthly"` +} + +// ResetQuota resets daily, weekly, and/or monthly usage for a subscription. 
+// POST /api/v1/admin/subscriptions/:id/reset-quota +func (h *SubscriptionHandler) ResetQuota(c *gin.Context) { + subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid subscription ID") + return + } + var req ResetSubscriptionQuotaRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if !req.Daily && !req.Weekly && !req.Monthly { + response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true") + return + } + sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub)) +} + // Revoke handles revoking a subscription // DELETE /api/v1/admin/subscriptions/:id func (h *SubscriptionHandler) Revoke(c *gin.Context) { diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index d0bba773..7a3135b8 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -61,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct { // GET /api/v1/admin/usage func (h *UsageHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) + exactTotal := false + if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" { + parsed, err := strconv.ParseBool(exactTotalRaw) + if err != nil { + response.BadRequest(c, "Invalid exact_total value, use true or false") + return + } + exactTotal = parsed + } // Parse filters var userID, apiKeyID, accountID, groupID int64 @@ -150,8 +159,8 @@ func (h *UsageHandler) List(c *gin.Context) { response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") return } - // Set end time to end of day - t = t.Add(24*time.Hour - time.Nanosecond) + // Use half-open range 
[start, end), move to next calendar day start (DST-safe). + t = t.AddDate(0, 0, 1) endTime = &t } @@ -167,6 +176,7 @@ func (h *UsageHandler) List(c *gin.Context) { BillingType: billingType, StartTime: startTime, EndTime: endTime, + ExactTotal: exactTotal, } records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters) @@ -275,7 +285,8 @@ func (h *UsageHandler) Stats(c *gin.Context) { response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") return } - endTime = endTime.Add(24*time.Hour - time.Nanosecond) + // 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。 + endTime = endTime.AddDate(0, 0, 1) } else { period := c.DefaultQuery("period", "today") switch period { diff --git a/backend/internal/handler/admin/usage_handler_request_type_test.go b/backend/internal/handler/admin/usage_handler_request_type_test.go index 21add574..3f158316 100644 --- a/backend/internal/handler/admin/usage_handler_request_type_test.go +++ b/backend/internal/handler/admin/usage_handler_request_type_test.go @@ -80,6 +80,29 @@ func TestAdminUsageListInvalidStream(t *testing.T) { require.Equal(t, http.StatusBadRequest, rec.Code) } +func TestAdminUsageListExactTotalTrue(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.True(t, repo.listFilters.ExactTotal) +} + +func TestAdminUsageListInvalidExactTotal(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + func TestAdminUsageStatsRequestTypePriority(t *testing.T) { repo := &adminUsageRepoCapture{} 
router := newAdminUsageRequestTypeTestRouter(repo) diff --git a/backend/internal/handler/admin/user_attribute_handler.go b/backend/internal/handler/admin/user_attribute_handler.go index 2f326279..3f84076e 100644 --- a/backend/internal/handler/admin/user_attribute_handler.go +++ b/backend/internal/handler/admin/user_attribute_handler.go @@ -1,7 +1,9 @@ package admin import ( + "encoding/json" "strconv" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -67,6 +69,8 @@ type BatchUserAttributesResponse struct { Attributes map[int64]map[int64]string `json:"attributes"` } +var userAttributesBatchCache = newSnapshotCache(30 * time.Second) + // AttributeDefinitionResponse represents attribute definition response type AttributeDefinitionResponse struct { ID int64 `json:"id"` @@ -327,16 +331,32 @@ func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) { return } - if len(req.UserIDs) == 0 { + userIDs := normalizeInt64IDList(req.UserIDs) + if len(userIDs) == 0 { response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}}) return } - attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs) + keyRaw, _ := json.Marshal(struct { + UserIDs []int64 `json:"user_ids"` + }{ + UserIDs: userIDs, + }) + cacheKey := string(keyRaw) + if cached, ok := userAttributesBatchCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), userIDs) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, BatchUserAttributesResponse{Attributes: attrs}) + payload := BatchUserAttributesResponse{Attributes: attrs} + userAttributesBatchCache.Set(cacheKey, payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) } diff --git a/backend/internal/handler/admin/user_handler.go 
b/backend/internal/handler/admin/user_handler.go index f85c060e..5a55ab14 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -91,6 +91,10 @@ func (h *UserHandler) List(c *gin.Context) { Search: search, Attributes: parseAttributeFilters(c), } + if raw, ok := c.GetQuery("include_subscriptions"); ok { + includeSubscriptions := parseBoolQueryWithDefault(raw, true) + filters.IncludeSubscriptions = &includeSubscriptions + } users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters) if err != nil { diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 61762744..951aed08 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -4,6 +4,7 @@ package handler import ( "context" "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -36,6 +37,11 @@ type CreateAPIKeyRequest struct { IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 Quota *float64 `json:"quota"` // 配额限制 (USD) ExpiresInDays *int `json:"expires_in_days"` // 过期天数 + + // Rate limit fields (0 = unlimited) + RateLimit5h *float64 `json:"rate_limit_5h"` + RateLimit1d *float64 `json:"rate_limit_1d"` + RateLimit7d *float64 `json:"rate_limit_7d"` } // UpdateAPIKeyRequest represents the update API key request payload @@ -48,6 +54,12 @@ type UpdateAPIKeyRequest struct { Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制 ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601) ResetQuota *bool `json:"reset_quota"` // 重置已用配额 + + // Rate limit fields (nil = no change, 0 = unlimited) + RateLimit5h *float64 `json:"rate_limit_5h"` + RateLimit1d *float64 `json:"rate_limit_1d"` + RateLimit7d *float64 `json:"rate_limit_7d"` + ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // 重置限速用量 } // List handles listing user's API keys with pagination @@ -62,7 +74,23 @@ func (h *APIKeyHandler) List(c 
*gin.Context) { page, pageSize := response.ParsePagination(c) params := pagination.PaginationParams{Page: page, PageSize: pageSize} - keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params) + // Parse filter parameters + var filters service.APIKeyListFilters + if search := strings.TrimSpace(c.Query("search")); search != "" { + if len(search) > 100 { + search = search[:100] + } + filters.Search = search + } + filters.Status = c.Query("status") + if groupIDStr := c.Query("group_id"); groupIDStr != "" { + gid, err := strconv.ParseInt(groupIDStr, 10, 64) + if err == nil { + filters.GroupID = &gid + } + } + + keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params, filters) if err != nil { response.ErrorFrom(c, err) return @@ -131,6 +159,15 @@ func (h *APIKeyHandler) Create(c *gin.Context) { if req.Quota != nil { svcReq.Quota = *req.Quota } + if req.RateLimit5h != nil { + svcReq.RateLimit5h = *req.RateLimit5h + } + if req.RateLimit1d != nil { + svcReq.RateLimit1d = *req.RateLimit1d + } + if req.RateLimit7d != nil { + svcReq.RateLimit7d = *req.RateLimit7d + } executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq) @@ -163,10 +200,14 @@ func (h *APIKeyHandler) Update(c *gin.Context) { } svcReq := service.UpdateAPIKeyRequest{ - IPWhitelist: req.IPWhitelist, - IPBlacklist: req.IPBlacklist, - Quota: req.Quota, - ResetQuota: req.ResetQuota, + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, + Quota: req.Quota, + ResetQuota: req.ResetQuota, + RateLimit5h: req.RateLimit5h, + RateLimit1d: req.RateLimit1d, + RateLimit7d: req.RateLimit7d, + ResetRateLimitUsage: req.ResetRateLimitUsage, } if req.Name != "" { svcReq.Name = &req.Name diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 1ffa9d71..f4ddf890 100644 --- 
a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) { return } + // Backend mode: only admin can login + if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { + response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") + return + } + h.respondWithTokenPair(c, user) } @@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { return } - // Delete the login session - _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) - - // Get the user + // Get the user (before session deletion so we can check backend mode) user, err := h.userService.GetByID(c.Request.Context(), session.UserID) if err != nil { response.ErrorFrom(c, err) return } + // Backend mode: only admin can login (check BEFORE deleting session) + if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { + response.Forbidden(c, "Backend mode is active. 
Only admin login is allowed.") + return + } + + // Delete the login session (only after all checks pass) + _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) + h.respondWithTokenPair(c, user) } @@ -447,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { return } - frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL) + frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context())) if frontendBaseURL == "" { - slog.Error("server.frontend_url not configured; cannot build password reset link") + slog.Error("frontend_url not configured in settings or config; cannot build password reset link") response.InternalError(c, "Password reset is not configured") return } @@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) { return } - tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken) + result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken) if err != nil { response.ErrorFrom(c, err) return } + // Backend mode: block non-admin token refresh + if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" { + response.Forbidden(c, "Backend mode is active. 
Only admin login is allowed.") + return + } + response.Success(c, RefreshTokenResponse{ - AccessToken: tokenPair.AccessToken, - RefreshToken: tokenPair.RefreshToken, - ExpiresIn: tokenPair.ExpiresIn, + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresIn: result.ExpiresIn, TokenType: "Bearer", }) } diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 0ccf47e4..0c7c2da7 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { email = linuxDoSyntheticEmail(subject) } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username) + // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { + if errors.Is(err, service.ErrOAuthInvitationRequired) { + pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username) + if tokenErr != nil { + redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "") + return + } + fragment := url.Values{} + fragment.Set("error", "invitation_required") + fragment.Set("pending_oauth_token", pendingToken) + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) + return + } // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) return @@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { redirectWithFragment(c, frontendCallback, fragment) } +type completeLinuxDoOAuthRequest struct { + PendingOAuthToken string `json:"pending_oauth_token" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` +} + +// CompleteLinuxDoOAuthRegistration 
completes a pending OAuth registration by validating +// the invitation code and creating the user account. +// POST /api/v1/auth/oauth/linuxdo/complete-registration +func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { + var req completeLinuxDoOAuthRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()}) + return + } + + email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"}) + return + } + + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { + response.ErrorFrom(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) { if h != nil && h.settingSvc != nil { return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx) diff --git a/backend/internal/handler/dto/announcement.go b/backend/internal/handler/dto/announcement.go index bc0db1b2..16650b8e 100644 --- a/backend/internal/handler/dto/announcement.go +++ b/backend/internal/handler/dto/announcement.go @@ -7,10 +7,11 @@ import ( ) type Announcement struct { - ID int64 `json:"id"` - Title string `json:"title"` - Content string `json:"content"` - Status string `json:"status"` + ID int64 `json:"id"` + Title string `json:"title"` + Content string `json:"content"` + Status string `json:"status"` + NotifyMode string `json:"notify_mode"` Targeting service.AnnouncementTargeting `json:"targeting"` @@ -25,9 +26,10 @@ type Announcement struct { } type UserAnnouncement struct { - ID int64 `json:"id"` - Title string 
`json:"title"` - Content string `json:"content"` + ID int64 `json:"id"` + Title string `json:"title"` + Content string `json:"content"` + NotifyMode string `json:"notify_mode"` StartsAt *time.Time `json:"starts_at,omitempty"` EndsAt *time.Time `json:"ends_at,omitempty"` @@ -43,17 +45,18 @@ func AnnouncementFromService(a *service.Announcement) *Announcement { return nil } return &Announcement{ - ID: a.ID, - Title: a.Title, - Content: a.Content, - Status: a.Status, - Targeting: a.Targeting, - StartsAt: a.StartsAt, - EndsAt: a.EndsAt, - CreatedBy: a.CreatedBy, - UpdatedBy: a.UpdatedBy, - CreatedAt: a.CreatedAt, - UpdatedAt: a.UpdatedAt, + ID: a.ID, + Title: a.Title, + Content: a.Content, + Status: a.Status, + NotifyMode: a.NotifyMode, + Targeting: a.Targeting, + StartsAt: a.StartsAt, + EndsAt: a.EndsAt, + CreatedBy: a.CreatedBy, + UpdatedBy: a.UpdatedBy, + CreatedAt: a.CreatedAt, + UpdatedAt: a.UpdatedAt, } } @@ -62,13 +65,14 @@ func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement return nil } return &UserAnnouncement{ - ID: a.Announcement.ID, - Title: a.Announcement.Title, - Content: a.Announcement.Content, - StartsAt: a.Announcement.StartsAt, - EndsAt: a.Announcement.EndsAt, - ReadAt: a.ReadAt, - CreatedAt: a.Announcement.CreatedAt, - UpdatedAt: a.Announcement.UpdatedAt, + ID: a.Announcement.ID, + Title: a.Announcement.Title, + Content: a.Announcement.Content, + NotifyMode: a.Announcement.NotifyMode, + StartsAt: a.Announcement.StartsAt, + EndsAt: a.Announcement.EndsAt, + ReadAt: a.ReadAt, + CreatedAt: a.Announcement.CreatedAt, + UpdatedAt: a.Announcement.UpdatedAt, } } diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index f8298067..8e5f23e7 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -71,24 +71,46 @@ func APIKeyFromService(k *service.APIKey) *APIKey { if k == nil { return nil } - return &APIKey{ - ID: k.ID, - UserID: k.UserID, - Key: 
k.Key, - Name: k.Name, - GroupID: k.GroupID, - Status: k.Status, - IPWhitelist: k.IPWhitelist, - IPBlacklist: k.IPBlacklist, - LastUsedAt: k.LastUsedAt, - Quota: k.Quota, - QuotaUsed: k.QuotaUsed, - ExpiresAt: k.ExpiresAt, - CreatedAt: k.CreatedAt, - UpdatedAt: k.UpdatedAt, - User: UserFromServiceShallow(k.User), - Group: GroupFromServiceShallow(k.Group), + out := &APIKey{ + ID: k.ID, + UserID: k.UserID, + Key: k.Key, + Name: k.Name, + GroupID: k.GroupID, + Status: k.Status, + IPWhitelist: k.IPWhitelist, + IPBlacklist: k.IPBlacklist, + LastUsedAt: k.LastUsedAt, + Quota: k.Quota, + QuotaUsed: k.QuotaUsed, + ExpiresAt: k.ExpiresAt, + CreatedAt: k.CreatedAt, + UpdatedAt: k.UpdatedAt, + RateLimit5h: k.RateLimit5h, + RateLimit1d: k.RateLimit1d, + RateLimit7d: k.RateLimit7d, + Usage5h: k.EffectiveUsage5h(), + Usage1d: k.EffectiveUsage1d(), + Usage7d: k.EffectiveUsage7d(), + Window5hStart: k.Window5hStart, + Window1dStart: k.Window1dStart, + Window7dStart: k.Window7dStart, + User: UserFromServiceShallow(k.User), + Group: GroupFromServiceShallow(k.Group), } + if k.Window5hStart != nil && !service.IsWindowExpired(k.Window5hStart, service.RateLimitWindow5h) { + t := k.Window5hStart.Add(service.RateLimitWindow5h) + out.Reset5hAt = &t + } + if k.Window1dStart != nil && !service.IsWindowExpired(k.Window1dStart, service.RateLimitWindow1d) { + t := k.Window1dStart.Add(service.RateLimitWindow1d) + out.Reset1dAt = &t + } + if k.Window7dStart != nil && !service.IsWindowExpired(k.Window7dStart, service.RateLimitWindow7d) { + t := k.Window7dStart.Add(service.RateLimitWindow7d) + out.Reset7dAt = &t + } + return out } func GroupFromServiceShallow(g *service.Group) *Group { @@ -117,6 +139,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { ModelRouting: g.ModelRouting, ModelRoutingEnabled: g.ModelRoutingEnabled, MCPXMLInject: g.MCPXMLInject, + DefaultMappedModel: g.DefaultMappedModel, SupportedModelScopes: g.SupportedModelScopes, AccountCount: g.AccountCount, SortOrder: 
g.SortOrder, @@ -155,6 +178,7 @@ func groupFromServiceBase(g *service.Group) Group { FallbackGroupID: g.FallbackGroupID, FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, + AllowMessagesDispatch: g.AllowMessagesDispatch, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } @@ -174,6 +198,7 @@ func AccountFromServiceShallow(a *service.Account) *Account { Extra: a.Extra, ProxyID: a.ProxyID, Concurrency: a.Concurrency, + LoadFactor: a.LoadFactor, Priority: a.Priority, RateMultiplier: a.BillingRateMultiplier(), Status: a.Status, @@ -216,6 +241,10 @@ func AccountFromServiceShallow(a *service.Account) *Account { buffer := a.GetRPMStickyBuffer() out.RPMStickyBuffer = &buffer } + // 用户消息队列模式 + if mode := a.GetUserMsgQueueMode(); mode != "" { + out.UserMsgQueueMode = &mode + } // TLS指纹伪装开关 if a.IsTLSFingerprintEnabled() { enabled := true @@ -235,6 +264,50 @@ func AccountFromServiceShallow(a *service.Account) *Account { } } + // 提取账号配额限制(apikey / bedrock 类型有效) + if a.IsAPIKeyOrBedrock() { + if limit := a.GetQuotaLimit(); limit > 0 { + out.QuotaLimit = &limit + used := a.GetQuotaUsed() + out.QuotaUsed = &used + } + if limit := a.GetQuotaDailyLimit(); limit > 0 { + out.QuotaDailyLimit = &limit + used := a.GetQuotaDailyUsed() + out.QuotaDailyUsed = &used + } + if limit := a.GetQuotaWeeklyLimit(); limit > 0 { + out.QuotaWeeklyLimit = &limit + used := a.GetQuotaWeeklyUsed() + out.QuotaWeeklyUsed = &used + } + // 固定时间重置配置 + if mode := a.GetQuotaDailyResetMode(); mode == "fixed" { + out.QuotaDailyResetMode = &mode + hour := a.GetQuotaDailyResetHour() + out.QuotaDailyResetHour = &hour + } + if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" { + out.QuotaWeeklyResetMode = &mode + day := a.GetQuotaWeeklyResetDay() + out.QuotaWeeklyResetDay = &day + hour := a.GetQuotaWeeklyResetHour() + out.QuotaWeeklyResetHour = &hour + } + if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" { + tz := 
a.GetQuotaResetTimezone() + out.QuotaResetTimezone = &tz + } + if a.Extra != nil { + if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" { + out.QuotaDailyResetAt = &v + } + if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" { + out.QuotaWeeklyResetAt = &v + } + } + } + return out } @@ -448,7 +521,10 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { AccountID: l.AccountID, RequestID: l.RequestID, Model: l.Model, + ServiceTier: l.ServiceTier, ReasoningEffort: l.ReasoningEffort, + InboundEndpoint: l.InboundEndpoint, + UpstreamEndpoint: l.UpstreamEndpoint, GroupID: l.GroupID, SubscriptionID: l.SubscriptionID, InputTokens: l.InputTokens, diff --git a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go index d716bdc4..e4031970 100644 --- a/backend/internal/handler/dto/mappers_usage_test.go +++ b/backend/internal/handler/dto/mappers_usage_test.go @@ -71,3 +71,41 @@ func TestRequestTypeStringPtrNil(t *testing.T) { t.Parallel() require.Nil(t, requestTypeStringPtr(nil)) } + +func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) { + t.Parallel() + + serviceTier := "priority" + inboundEndpoint := "/v1/chat/completions" + upstreamEndpoint := "/v1/responses" + log := &service.UsageLog{ + RequestID: "req_3", + Model: "gpt-5.4", + ServiceTier: &serviceTier, + InboundEndpoint: &inboundEndpoint, + UpstreamEndpoint: &upstreamEndpoint, + AccountRateMultiplier: f64Ptr(1.5), + } + + userDTO := UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + require.NotNil(t, userDTO.ServiceTier) + require.Equal(t, serviceTier, *userDTO.ServiceTier) + require.NotNil(t, userDTO.InboundEndpoint) + require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint) + require.NotNil(t, userDTO.UpstreamEndpoint) + require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint) + require.NotNil(t, adminDTO.ServiceTier) + require.Equal(t, serviceTier, *adminDTO.ServiceTier) 
+ require.NotNil(t, adminDTO.InboundEndpoint) + require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint) + require.NotNil(t, adminDTO.UpstreamEndpoint) + require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint) + require.NotNil(t, adminDTO.AccountRateMultiplier) + require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12) +} + +func f64Ptr(value float64) *float64 { + return &value +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index e9086010..29b00bb8 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -1,14 +1,31 @@ package dto +import ( + "encoding/json" + "strings" +) + +// CustomMenuItem represents a user-configured custom menu entry. +type CustomMenuItem struct { + ID string `json:"id"` + Label string `json:"label"` + IconSVG string `json:"icon_svg"` + URL string `json:"url"` + Visibility string `json:"visibility"` // "user" or "admin" + SortOrder int `json:"sort_order"` +} + // SystemSettings represents the admin settings API response payload. 
type SystemSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 - TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + FrontendURL string `json:"frontend_url"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 SMTPHost string `json:"smtp_host"` SMTPPort int `json:"smtp_port"` @@ -27,17 +44,18 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url"` - SoraClientEnabled bool `json:"sora_client_enabled"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string 
`json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` @@ -61,6 +79,12 @@ type SystemSettings struct { OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"` MinClaudeCodeVersion string `json:"min_claude_code_version"` + + // 分组隔离 + AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` + + // Backend Mode + BackendModeEnabled bool `json:"backend_mode_enabled"` } type DefaultSubscriptionSetting struct { @@ -69,27 +93,30 @@ type DefaultSubscriptionSetting struct { } type PublicSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url"` - LinuxDoOAuthEnabled bool 
`json:"linuxdo_oauth_enabled"` - SoraClientEnabled bool `json:"sora_client_enabled"` - Version string `json:"version"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + SoraClientEnabled bool `json:"sora_client_enabled"` + BackendModeEnabled bool `json:"backend_mode_enabled"` + Version string `json:"version"` } // SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段) @@ -138,3 +165,49 @@ type StreamTimeoutSettings struct { ThresholdCount int `json:"threshold_count"` ThresholdWindowMinutes int `json:"threshold_window_minutes"` } + +// RectifierSettings 请求整流器配置 DTO +type RectifierSettings struct { + Enabled bool `json:"enabled"` + ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` + ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` +} + +// BetaPolicyRule Beta 策略规则 DTO +type BetaPolicyRule struct { + BetaToken string `json:"beta_token"` + Action string `json:"action"` + Scope string 
`json:"scope"` + ErrorMessage string `json:"error_message,omitempty"` +} + +// BetaPolicySettings Beta 策略配置 DTO +type BetaPolicySettings struct { + Rules []BetaPolicyRule `json:"rules"` +} + +// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. +// Returns empty slice on empty/invalid input. +func ParseCustomMenuItems(raw string) []CustomMenuItem { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return []CustomMenuItem{} + } + var items []CustomMenuItem + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []CustomMenuItem{} + } + return items +} + +// ParseUserVisibleMenuItems parses custom menu items and filters out admin-only entries. +func ParseUserVisibleMenuItems(raw string) []CustomMenuItem { + items := ParseCustomMenuItems(raw) + filtered := make([]CustomMenuItem, 0, len(items)) + for _, item := range items { + if item.Visibility != "admin" { + filtered = append(filtered, item) + } + } + return filtered +} diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index b5c0640f..c52e357e 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -47,6 +47,20 @@ type APIKey struct { CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` + // Rate limit fields + RateLimit5h float64 `json:"rate_limit_5h"` + RateLimit1d float64 `json:"rate_limit_1d"` + RateLimit7d float64 `json:"rate_limit_7d"` + Usage5h float64 `json:"usage_5h"` + Usage1d float64 `json:"usage_1d"` + Usage7d float64 `json:"usage_7d"` + Window5hStart *time.Time `json:"window_5h_start"` + Window1dStart *time.Time `json:"window_1d_start"` + Window7dStart *time.Time `json:"window_7d_start"` + Reset5hAt *time.Time `json:"reset_5h_at,omitempty"` + Reset1dAt *time.Time `json:"reset_1d_at,omitempty"` + Reset7dAt *time.Time `json:"reset_7d_at,omitempty"` + User *User `json:"user,omitempty"` Group *Group `json:"group,omitempty"` } @@ -85,6 +99,9 @@ 
type Group struct { // Sora 存储配额 SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + // OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程) + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -101,6 +118,9 @@ type AdminGroup struct { // MCP XML 协议注入(仅 antigravity 平台使用) MCPXMLInject bool `json:"mcp_xml_inject"` + // OpenAI Messages 调度配置(仅 openai 平台使用) + DefaultMappedModel string `json:"default_mapped_model"` + // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` @@ -120,6 +140,7 @@ type Account struct { Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` Concurrency int `json:"concurrency"` + LoadFactor *int `json:"load_factor,omitempty"` Priority int `json:"priority"` RateMultiplier float64 `json:"rate_multiplier"` Status string `json:"status"` @@ -155,9 +176,10 @@ type Account struct { // RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效) // 从 extra 字段提取,方便前端显示和编辑 - BaseRPM *int `json:"base_rpm,omitempty"` - RPMStrategy *string `json:"rpm_strategy,omitempty"` - RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"` + BaseRPM *int `json:"base_rpm,omitempty"` + RPMStrategy *string `json:"rpm_strategy,omitempty"` + RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"` + UserMsgQueueMode *string `json:"user_msg_queue_mode,omitempty"` // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效) // 从 extra 字段提取,方便前端显示和编辑 @@ -173,6 +195,24 @@ type Account struct { CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"` CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"` + // API Key 账号配额限制 + QuotaLimit *float64 `json:"quota_limit,omitempty"` + QuotaUsed *float64 `json:"quota_used,omitempty"` + QuotaDailyLimit *float64 `json:"quota_daily_limit,omitempty"` + QuotaDailyUsed *float64 `json:"quota_daily_used,omitempty"` + 
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"` + QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"` + + // 配额固定时间重置配置 + QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"` + QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"` + QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"` + QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"` + QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"` + QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"` + QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"` + QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` @@ -292,9 +332,15 @@ type UsageLog struct { AccountID int64 `json:"account_id"` RequestID string `json:"request_id"` Model string `json:"model"` - // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API). - // nil means not provided / not applicable. + // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". + ServiceTier *string `json:"service_tier,omitempty"` + // ReasoningEffort is the request's reasoning effort level. + // OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max". ReasoningEffort *string `json:"reasoning_effort,omitempty"` + // InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions. + InboundEndpoint *string `json:"inbound_endpoint,omitempty"` + // UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses. 
+ UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"` GroupID *int64 `json:"group_id"` SubscriptionID *int64 `json:"subscription_id"` diff --git a/backend/internal/handler/endpoint.go b/backend/internal/handler/endpoint.go new file mode 100644 index 00000000..b1200988 --- /dev/null +++ b/backend/internal/handler/endpoint.go @@ -0,0 +1,174 @@ +package handler + +import ( + "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// ────────────────────────────────────────────────────────── +// Canonical inbound / upstream endpoint paths. +// All normalization and derivation reference this single set +// of constants — add new paths HERE when a new API surface +// is introduced. +// ────────────────────────────────────────────────────────── + +const ( + EndpointMessages = "/v1/messages" + EndpointChatCompletions = "/v1/chat/completions" + EndpointResponses = "/v1/responses" + EndpointGeminiModels = "/v1beta/models" +) + +// gin.Context keys used by the middleware and helpers below. +const ( + ctxKeyInboundEndpoint = "_gateway_inbound_endpoint" +) + +// ────────────────────────────────────────────────────────── +// Normalization functions +// ────────────────────────────────────────────────────────── + +// NormalizeInboundEndpoint maps a raw request path (which may carry +// prefixes like /antigravity, /openai, /sora) to its canonical form. 
+// +// "/antigravity/v1/messages" → "/v1/messages" +// "/v1/chat/completions" → "/v1/chat/completions" +// "/openai/v1/responses/foo" → "/v1/responses" +// "/v1beta/models/gemini:gen" → "/v1beta/models" +func NormalizeInboundEndpoint(path string) string { + path = strings.TrimSpace(path) + switch { + case strings.Contains(path, EndpointChatCompletions): + return EndpointChatCompletions + case strings.Contains(path, EndpointMessages): + return EndpointMessages + case strings.Contains(path, EndpointResponses): + return EndpointResponses + case strings.Contains(path, EndpointGeminiModels): + return EndpointGeminiModels + default: + return path + } +} + +// DeriveUpstreamEndpoint determines the upstream endpoint from the +// account platform and the normalized inbound endpoint. +// +// Platform-specific rules: +// - OpenAI always forwards to /v1/responses (with optional subpath +// such as /v1/responses/compact preserved from the raw URL). +// - Anthropic → /v1/messages +// - Gemini → /v1beta/models +// - Sora → /v1/chat/completions +// - Antigravity routes may target either Claude or Gemini, so the +// inbound endpoint is used to distinguish. +func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string { + inbound = strings.TrimSpace(inbound) + + switch platform { + case service.PlatformOpenAI: + // OpenAI forwards everything to the Responses API. + // Preserve subresource suffix (e.g. /v1/responses/compact). + if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" { + return EndpointResponses + suffix + } + return EndpointResponses + + case service.PlatformAnthropic: + return EndpointMessages + + case service.PlatformGemini: + return EndpointGeminiModels + + case service.PlatformSora: + return EndpointChatCompletions + + case service.PlatformAntigravity: + // Antigravity accounts serve both Claude and Gemini. 
+ if inbound == EndpointGeminiModels { + return EndpointGeminiModels + } + return EndpointMessages + } + + // Unknown platform — fall back to inbound. + return inbound +} + +// responsesSubpathSuffix extracts the part after "/responses" in a raw +// request path, e.g. "/openai/v1/responses/compact" → "/compact". +// Returns "" when there is no meaningful suffix. +func responsesSubpathSuffix(rawPath string) string { + trimmed := strings.TrimRight(strings.TrimSpace(rawPath), "/") + idx := strings.LastIndex(trimmed, "/responses") + if idx < 0 { + return "" + } + suffix := trimmed[idx+len("/responses"):] + if suffix == "" || suffix == "/" { + return "" + } + if !strings.HasPrefix(suffix, "/") { + return "" + } + return suffix +} + +// ────────────────────────────────────────────────────────── +// Middleware +// ────────────────────────────────────────────────────────── + +// InboundEndpointMiddleware normalizes the request path and stores the +// canonical inbound endpoint in gin.Context so that every handler in +// the chain can read it via GetInboundEndpoint. +// +// Apply this middleware to all gateway route groups. +func InboundEndpointMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + path := c.FullPath() + if path == "" && c.Request != nil && c.Request.URL != nil { + path = c.Request.URL.Path + } + c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(path)) + c.Next() + } +} + +// ────────────────────────────────────────────────────────── +// Context helpers — used by handlers before building +// RecordUsageInput / RecordUsageLongContextInput. +// ────────────────────────────────────────────────────────── + +// GetInboundEndpoint returns the canonical inbound endpoint stored by +// InboundEndpointMiddleware. If the middleware did not run (e.g. in +// tests), it falls back to normalizing c.FullPath() on the fly. 
+func GetInboundEndpoint(c *gin.Context) string { + if v, ok := c.Get(ctxKeyInboundEndpoint); ok { + if s, ok := v.(string); ok && s != "" { + return s + } + } + // Fallback: normalize on the fly. + path := "" + if c != nil { + path = c.FullPath() + if path == "" && c.Request != nil && c.Request.URL != nil { + path = c.Request.URL.Path + } + } + return NormalizeInboundEndpoint(path) +} + +// GetUpstreamEndpoint derives the upstream endpoint from the context +// and the account platform. Handlers call this after scheduling an +// account, passing account.Platform. +func GetUpstreamEndpoint(c *gin.Context, platform string) string { + inbound := GetInboundEndpoint(c) + rawPath := "" + if c != nil && c.Request != nil && c.Request.URL != nil { + rawPath = c.Request.URL.Path + } + return DeriveUpstreamEndpoint(inbound, rawPath, platform) +} diff --git a/backend/internal/handler/endpoint_test.go b/backend/internal/handler/endpoint_test.go new file mode 100644 index 00000000..a3767ac4 --- /dev/null +++ b/backend/internal/handler/endpoint_test.go @@ -0,0 +1,159 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func init() { gin.SetMode(gin.TestMode) } + +// ────────────────────────────────────────────────────────── +// NormalizeInboundEndpoint +// ────────────────────────────────────────────────────────── + +func TestNormalizeInboundEndpoint(t *testing.T) { + tests := []struct { + path string + want string + }{ + // Direct canonical paths. + {"/v1/messages", EndpointMessages}, + {"/v1/chat/completions", EndpointChatCompletions}, + {"/v1/responses", EndpointResponses}, + {"/v1beta/models", EndpointGeminiModels}, + + // Prefixed paths (antigravity, openai, sora). 
+ {"/antigravity/v1/messages", EndpointMessages}, + {"/openai/v1/responses", EndpointResponses}, + {"/openai/v1/responses/compact", EndpointResponses}, + {"/sora/v1/chat/completions", EndpointChatCompletions}, + {"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels}, + + // Gin route patterns with wildcards. + {"/v1beta/models/*modelAction", EndpointGeminiModels}, + {"/v1/responses/*subpath", EndpointResponses}, + + // Unknown path is returned as-is. + {"/v1/embeddings", "/v1/embeddings"}, + {"", ""}, + {" /v1/messages ", EndpointMessages}, + } + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + require.Equal(t, tt.want, NormalizeInboundEndpoint(tt.path)) + }) + } +} + +// ────────────────────────────────────────────────────────── +// DeriveUpstreamEndpoint +// ────────────────────────────────────────────────────────── + +func TestDeriveUpstreamEndpoint(t *testing.T) { + tests := []struct { + name string + inbound string + rawPath string + platform string + want string + }{ + // Anthropic. + {"anthropic messages", EndpointMessages, "/v1/messages", service.PlatformAnthropic, EndpointMessages}, + + // Gemini. + {"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels}, + + // Sora. + {"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions}, + + // OpenAI — always /v1/responses. 
+ {"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses}, + {"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"}, + {"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"}, + {"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses}, + {"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses}, + + // Antigravity — uses inbound to pick Claude vs Gemini upstream. + {"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages}, + {"antigravity gemini", EndpointGeminiModels, "/antigravity/v1beta/models", service.PlatformAntigravity, EndpointGeminiModels}, + + // Unknown platform — passthrough. + {"unknown platform", "/v1/embeddings", "/v1/embeddings", "unknown", "/v1/embeddings"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, DeriveUpstreamEndpoint(tt.inbound, tt.rawPath, tt.platform)) + }) + } +} + +// ────────────────────────────────────────────────────────── +// responsesSubpathSuffix +// ────────────────────────────────────────────────────────── + +func TestResponsesSubpathSuffix(t *testing.T) { + tests := []struct { + raw string + want string + }{ + {"/v1/responses", ""}, + {"/v1/responses/", ""}, + {"/v1/responses/compact", "/compact"}, + {"/openai/v1/responses/compact/detail", "/compact/detail"}, + {"/v1/messages", ""}, + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.raw, func(t *testing.T) { + require.Equal(t, tt.want, responsesSubpathSuffix(tt.raw)) + }) + } +} + +// ────────────────────────────────────────────────────────── +// InboundEndpointMiddleware + context helpers +// ────────────────────────────────────────────────────────── + 
+func TestInboundEndpointMiddleware(t *testing.T) { + router := gin.New() + router.Use(InboundEndpointMiddleware()) + + var captured string + router.POST("/v1/messages", func(c *gin.Context) { + captured = GetInboundEndpoint(c) + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, EndpointMessages, captured) +} + +func TestGetInboundEndpoint_FallbackWithoutMiddleware(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/antigravity/v1/messages", nil) + + // Middleware did not run — fallback to normalizing c.Request.URL.Path. + got := GetInboundEndpoint(c) + require.Equal(t, EndpointMessages, got) +} + +func TestGetUpstreamEndpoint_FullFlow(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses/compact", nil) + + // Simulate middleware. 
+ c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(c.Request.URL.Path)) + + got := GetUpstreamEndpoint(c, service.PlatformOpenAI) + require.Equal(t, "/v1/responses/compact", got) +} diff --git a/backend/internal/handler/failover_loop.go b/backend/internal/handler/failover_loop.go index b2583301..6d8ddc72 100644 --- a/backend/internal/handler/failover_loop.go +++ b/backend/internal/handler/failover_loop.go @@ -30,7 +30,7 @@ const ( const ( // maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误) - maxSameAccountRetries = 2 + maxSameAccountRetries = 3 // sameAccountRetryDelay 同账号重试间隔 sameAccountRetryDelay = 500 * time.Millisecond // singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。 diff --git a/backend/internal/handler/failover_loop_test.go b/backend/internal/handler/failover_loop_test.go index 5a41b2dd..2c65ebc2 100644 --- a/backend/internal/handler/failover_loop_test.go +++ b/backend/internal/handler/failover_loop_test.go @@ -291,35 +291,31 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) { require.Less(t, elapsed, 2*time.Second) }) - t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) { + t.Run("达到最大重试次数前均返回FailoverContinue", func(t *testing.T) { mock := &mockTempUnscheduler{} fs := NewFailoverState(3, false) err := newTestFailoverErr(400, true, false) - // 第一次 - action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - require.Equal(t, FailoverContinue, action) - require.Equal(t, 1, fs.SameAccountRetryCount[100]) + for i := 1; i <= maxSameAccountRetries; i++ { + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, i, fs.SameAccountRetryCount[100]) + } - // 第二次 - action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - require.Equal(t, FailoverContinue, action) - require.Equal(t, 2, fs.SameAccountRetryCount[100]) - - require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule") + require.Empty(t, 
mock.calls, "达到最大重试次数前均不应调用 TempUnschedule") }) - t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) { + t.Run("超过最大重试次数后触发TempUnschedule并切换", func(t *testing.T) { mock := &mockTempUnscheduler{} fs := NewFailoverState(3, false) err := newTestFailoverErr(400, true, false) - // 第一次、第二次重试 - fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - require.Equal(t, 2, fs.SameAccountRetryCount[100]) + for i := 0; i < maxSameAccountRetries; i++ { + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + } + require.Equal(t, maxSameAccountRetries, fs.SameAccountRetryCount[100]) - // 第三次:重试已达到 maxSameAccountRetries(2),应切换账号 + // 第 maxSameAccountRetries+1 次:重试耗尽,应切换账号 action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) require.Equal(t, FailoverContinue, action) require.Equal(t, 1, fs.SwitchCount) @@ -354,13 +350,14 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) { err := newTestFailoverErr(400, true, false) // 耗尽账号 100 的重试 - fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - // 第三次: 重试耗尽 → 切换 + for i := 0; i < maxSameAccountRetries; i++ { + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + } + // 第 maxSameAccountRetries+1 次: 重试耗尽 → 切换 action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) require.Equal(t, FailoverContinue, action) - // 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换 + // 再次遇到账号 100,计数仍为 maxSameAccountRetries,条件不满足 → 直接切换 action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) require.Equal(t, FailoverContinue, action) require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule") @@ -386,9 +383,10 @@ func TestHandleFailoverError_TempUnschedule(t *testing.T) { fs := NewFailoverState(3, false) err := newTestFailoverErr(502, true, false) - // 耗尽重试 - 
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) - fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + for i := 0; i < maxSameAccountRetries; i++ { + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + } + // 再次触发时才会执行 TempUnschedule + 切换 fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) require.Len(t, mock.calls, 1) @@ -521,17 +519,16 @@ func TestHandleFailoverError_IntegrationScenario(t *testing.T) { mock := &mockTempUnscheduler{} fs := NewFailoverState(3, true) // hasBoundSession=true - // 1. 账号 100 遇到可重试错误,同账号重试 2 次 + // 1. 账号 100 遇到可重试错误,同账号重试 maxSameAccountRetries 次 retryErr := newTestFailoverErr(400, true, false) - action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) - require.Equal(t, FailoverContinue, action) + for i := 0; i < maxSameAccountRetries; i++ { + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + } require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling") - action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) - require.Equal(t, FailoverContinue, action) - - // 2. 账号 100 重试耗尽 → TempUnschedule + 切换 - action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + // 2. 
账号 100 超过重试上限 → TempUnschedule + 切换 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) require.Equal(t, FailoverContinue, action) require.Equal(t, 1, fs.SwitchCount) require.Len(t, mock.calls, 1) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 2bd59f32..831029c4 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -22,6 +22,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -45,6 +46,7 @@ type GatewayHandler struct { usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper + userMsgQueueHelper *UserMsgQueueHelper maxAccountSwitches int maxAccountSwitchesGemini int cfg *config.Config @@ -63,6 +65,7 @@ func NewGatewayHandler( apiKeyService *service.APIKeyService, usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, + userMsgQueueService *service.UserMessageQueueService, cfg *config.Config, settingService *service.SettingService, ) *GatewayHandler { @@ -78,6 +81,13 @@ func NewGatewayHandler( maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini } } + + // 初始化用户消息串行队列 helper + var umqHelper *UserMsgQueueHelper + if userMsgQueueService != nil && cfg != nil { + umqHelper = NewUserMsgQueueHelper(userMsgQueueService, SSEPingFormatClaude, pingInterval) + } + return &GatewayHandler{ gatewayService: gatewayService, geminiCompatService: geminiCompatService, @@ -89,6 +99,7 @@ func NewGatewayHandler( usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, 
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), + userMsgQueueHelper: umqHelper, maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, cfg: cfg, @@ -380,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if fs.SwitchCount > 0 { requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } + // 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover + writerSizeBeforeForward := c.Writer.Size() if account.Platform == service.PlatformAntigravity { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession) } else { @@ -391,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { + // 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化 + if c.Writer.Size() != writerSizeBeforeForward { + h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true) + return + } action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) switch action { case FailoverContinue: @@ -423,19 +441,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + + if result.ReasoningEffort == nil { + result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) + } // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: 
subscription, - UserAgent: userAgent, - IPAddress: clientIP, - ForceCacheBilling: fs.ForceCacheBilling, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + ForceCacheBilling: fs.ForceCacheBilling, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), @@ -566,21 +594,90 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + // ===== 用户消息串行队列 START ===== + var queueRelease func() + umqMode := h.getUserMsgQueueMode(account, parsedReq) + + switch umqMode { + case config.UMQModeSerialize: + // 串行模式:获取锁 + RPM 延迟 + 释放(当前行为不变) + baseRPM := account.GetBaseRPM() + release, qErr := h.userMsgQueueHelper.AcquireWithWait( + c, account.ID, baseRPM, reqStream, &streamStarted, + h.cfg.Gateway.UserMessageQueue.WaitTimeout(), + reqLog, + ) + if qErr != nil { + // fail-open: 记录 warn,不阻止请求 + reqLog.Warn("gateway.umq_acquire_failed", + zap.Int64("account_id", account.ID), + zap.Error(qErr), + ) + } else { + queueRelease = release + } + + case config.UMQModeThrottle: + // 软性限速:仅施加 RPM 自适应延迟,不阻塞并发 + baseRPM := account.GetBaseRPM() + if tErr := h.userMsgQueueHelper.ThrottleWithPing( + c, account.ID, baseRPM, reqStream, &streamStarted, + h.cfg.Gateway.UserMessageQueue.WaitTimeout(), + reqLog, + ); tErr != nil { + reqLog.Warn("gateway.umq_throttle_failed", + zap.Int64("account_id", account.ID), + zap.Error(tErr), + ) + } + + default: + if umqMode != "" { + reqLog.Warn("gateway.umq_unknown_mode", + zap.String("mode", umqMode), + zap.Int64("account_id", account.ID), + ) + } + } + + // 用 wrapReleaseOnDone 确保 context 取消时自动释放(仅 serialize 模式有 queueRelease) + queueRelease = 
wrapReleaseOnDone(c.Request.Context(), queueRelease) + // 注入回调到 ParsedRequest:使用外层 wrapper 以便提前清理 AfterFunc + parsedReq.OnUpstreamAccepted = queueRelease + // ===== 用户消息串行队列 END ===== + // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } + // 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover + writerSizeBeforeForward := c.Writer.Size() if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) } else { result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) } + + // 兜底释放串行锁(正常情况已通过回调提前释放) + if queueRelease != nil { + queueRelease() + } + // 清理回调引用,防止 failover 重试时旧回调被错误调用 + parsedReq.OnUpstreamAccepted = nil + if accountReleaseFunc != nil { accountReleaseFunc() } if err != nil { + // Beta policy block: return 400 immediately, no failover + var betaBlockedErr *service.BetaBlockedError + if errors.As(err, &betaBlockedErr) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message) + return + } + var promptTooLongErr *service.PromptTooLongError if errors.As(err, &promptTooLongErr) { reqLog.Warn("gateway.prompt_too_long_from_antigravity", @@ -626,6 +723,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { + // 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化 + if c.Writer.Size() != writerSizeBeforeForward { + h.handleFailoverExhausted(c, failoverErr, account.Platform, true) + return + } action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) switch action { case FailoverContinue: @@ -658,19 +760,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 
gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + + if result.ReasoningEffort == nil { + result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) + } // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: currentAPIKey, - User: currentAPIKey.User, - Account: account, - Subscription: currentSubscription, - UserAgent: userAgent, - IPAddress: clientIP, - ForceCacheBilling: fs.ForceCacheBilling, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: currentAPIKey, + User: currentAPIKey.User, + Account: account, + Subscription: currentSubscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + ForceCacheBilling: fs.ForceCacheBilling, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), @@ -774,6 +886,10 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service // Usage handles getting account balance and usage statistics for CC Switch integration // GET /v1/usage +// +// Two modes: +// - quota_limited: API Key has quota or rate limits configured. Returns key-level limits/usage. +// - unrestricted: No key-level limits. Returns subscription or wallet balance info. 
func (h *GatewayHandler) Usage(c *gin.Context) { apiKey, ok := middleware2.GetAPIKeyFromContext(c) if !ok { @@ -787,54 +903,195 @@ func (h *GatewayHandler) Usage(c *gin.Context) { return } + ctx := c.Request.Context() + + // 解析可选的日期范围参数(用于 model_stats 查询) + startTime, endTime := h.parseUsageDateRange(c) + // Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应 - var usageData gin.H + usageData := h.buildUsageData(ctx, apiKey.ID) + + // Best-effort: 获取模型统计 + var modelStats any if h.usageService != nil { - dashStats, err := h.usageService.GetAPIKeyDashboardStats(c.Request.Context(), apiKey.ID) - if err == nil && dashStats != nil { - usageData = gin.H{ - "today": gin.H{ - "requests": dashStats.TodayRequests, - "input_tokens": dashStats.TodayInputTokens, - "output_tokens": dashStats.TodayOutputTokens, - "cache_creation_tokens": dashStats.TodayCacheCreationTokens, - "cache_read_tokens": dashStats.TodayCacheReadTokens, - "total_tokens": dashStats.TodayTokens, - "cost": dashStats.TodayCost, - "actual_cost": dashStats.TodayActualCost, - }, - "total": gin.H{ - "requests": dashStats.TotalRequests, - "input_tokens": dashStats.TotalInputTokens, - "output_tokens": dashStats.TotalOutputTokens, - "cache_creation_tokens": dashStats.TotalCacheCreationTokens, - "cache_read_tokens": dashStats.TotalCacheReadTokens, - "total_tokens": dashStats.TotalTokens, - "cost": dashStats.TotalCost, - "actual_cost": dashStats.TotalActualCost, - }, - "average_duration_ms": dashStats.AverageDurationMs, - "rpm": dashStats.Rpm, - "tpm": dashStats.Tpm, + if stats, err := h.usageService.GetAPIKeyModelStats(ctx, apiKey.ID, startTime, endTime); err == nil && len(stats) > 0 { + modelStats = stats + } + } + + // 判断模式: key 有总额度或速率限制 → quota_limited,否则 → unrestricted + isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits() + + if isQuotaLimited { + h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats) + return + } + + h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats) +} + +// 
parseUsageDateRange 解析 start_date / end_date query params,默认返回近 30 天范围 +func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Time) { + now := timezone.Now() + endTime := now + startTime := now.AddDate(0, 0, -30) + + if s := c.Query("start_date"); s != "" { + if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil { + startTime = t + } + } + if s := c.Query("end_date"); s != "" { + if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil { + endTime = t.AddDate(0, 0, 1) // half-open range upper bound + } + } + return startTime, endTime +} + +// buildUsageData 构建 today/total 用量摘要 +func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin.H { + if h.usageService == nil { + return nil + } + dashStats, err := h.usageService.GetAPIKeyDashboardStats(ctx, apiKeyID) + if err != nil || dashStats == nil { + return nil + } + return gin.H{ + "today": gin.H{ + "requests": dashStats.TodayRequests, + "input_tokens": dashStats.TodayInputTokens, + "output_tokens": dashStats.TodayOutputTokens, + "cache_creation_tokens": dashStats.TodayCacheCreationTokens, + "cache_read_tokens": dashStats.TodayCacheReadTokens, + "total_tokens": dashStats.TodayTokens, + "cost": dashStats.TodayCost, + "actual_cost": dashStats.TodayActualCost, + }, + "total": gin.H{ + "requests": dashStats.TotalRequests, + "input_tokens": dashStats.TotalInputTokens, + "output_tokens": dashStats.TotalOutputTokens, + "cache_creation_tokens": dashStats.TotalCacheCreationTokens, + "cache_read_tokens": dashStats.TotalCacheReadTokens, + "total_tokens": dashStats.TotalTokens, + "cost": dashStats.TotalCost, + "actual_cost": dashStats.TotalActualCost, + }, + "average_duration_ms": dashStats.AverageDurationMs, + "rpm": dashStats.Rpm, + "tpm": dashStats.Tpm, + } +} + +// usageQuotaLimited 处理 quota_limited 模式的响应 +func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) { + resp := 
gin.H{ + "mode": "quota_limited", + "isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired, + "status": apiKey.Status, + } + + // 总额度信息 + if apiKey.Quota > 0 { + remaining := apiKey.GetQuotaRemaining() + resp["quota"] = gin.H{ + "limit": apiKey.Quota, + "used": apiKey.QuotaUsed, + "remaining": remaining, + "unit": "USD", + } + resp["remaining"] = remaining + resp["unit"] = "USD" + } + + // 速率限制信息(从 DB 获取实时用量) + if apiKey.HasRateLimits() && h.apiKeyService != nil { + rateLimitData, err := h.apiKeyService.GetRateLimitData(ctx, apiKey.ID) + if err == nil && rateLimitData != nil { + var rateLimits []gin.H + if apiKey.RateLimit5h > 0 { + used := rateLimitData.EffectiveUsage5h() + entry := gin.H{ + "window": "5h", + "limit": apiKey.RateLimit5h, + "used": used, + "remaining": max(0, apiKey.RateLimit5h-used), + "window_start": rateLimitData.Window5hStart, + } + if rateLimitData.Window5hStart != nil && !service.IsWindowExpired(rateLimitData.Window5hStart, service.RateLimitWindow5h) { + entry["reset_at"] = rateLimitData.Window5hStart.Add(service.RateLimitWindow5h) + } + rateLimits = append(rateLimits, entry) + } + if apiKey.RateLimit1d > 0 { + used := rateLimitData.EffectiveUsage1d() + entry := gin.H{ + "window": "1d", + "limit": apiKey.RateLimit1d, + "used": used, + "remaining": max(0, apiKey.RateLimit1d-used), + "window_start": rateLimitData.Window1dStart, + } + if rateLimitData.Window1dStart != nil && !service.IsWindowExpired(rateLimitData.Window1dStart, service.RateLimitWindow1d) { + entry["reset_at"] = rateLimitData.Window1dStart.Add(service.RateLimitWindow1d) + } + rateLimits = append(rateLimits, entry) + } + if apiKey.RateLimit7d > 0 { + used := rateLimitData.EffectiveUsage7d() + entry := gin.H{ + "window": "7d", + "limit": apiKey.RateLimit7d, + "used": used, + "remaining": max(0, apiKey.RateLimit7d-used), + "window_start": rateLimitData.Window7dStart, + } + if 
rateLimitData.Window7dStart != nil && !service.IsWindowExpired(rateLimitData.Window7dStart, service.RateLimitWindow7d) { + entry["reset_at"] = rateLimitData.Window7dStart.Add(service.RateLimitWindow7d) + } + rateLimits = append(rateLimits, entry) + } + if len(rateLimits) > 0 { + resp["rate_limits"] = rateLimits } } } - // 订阅模式:返回订阅限额信息 + 用量统计 + // 过期时间 + if apiKey.ExpiresAt != nil { + resp["expires_at"] = apiKey.ExpiresAt + resp["days_until_expiry"] = apiKey.GetDaysUntilExpiry() + } + + if usageData != nil { + resp["usage"] = usageData + } + if modelStats != nil { + resp["model_stats"] = modelStats + } + + c.JSON(http.StatusOK, resp) +} + +// usageUnrestricted 处理 unrestricted 模式的响应(向后兼容) +func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) { + // 订阅模式 if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() { - subscription, ok := middleware2.GetSubscriptionFromContext(c) - if !ok { - h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription") - return + resp := gin.H{ + "mode": "unrestricted", + "isValid": true, + "planName": apiKey.Group.Name, + "unit": "USD", } - remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription) - resp := gin.H{ - "isValid": true, - "planName": apiKey.Group.Name, - "remaining": remaining, - "unit": "USD", - "subscription": gin.H{ + // 订阅信息可能不在 context 中(/v1/usage 路径跳过了中间件的计费检查) + subscription, ok := middleware2.GetSubscriptionFromContext(c) + if ok { + remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription) + resp["remaining"] = remaining + resp["subscription"] = gin.H{ "daily_usage_usd": subscription.DailyUsageUSD, "weekly_usage_usd": subscription.WeeklyUsageUSD, "monthly_usage_usd": subscription.MonthlyUsageUSD, @@ -842,23 +1099,28 @@ func (h *GatewayHandler) Usage(c *gin.Context) { "weekly_limit_usd": apiKey.Group.WeeklyLimitUSD, "monthly_limit_usd": 
apiKey.Group.MonthlyLimitUSD, "expires_at": subscription.ExpiresAt, - }, + } } + if usageData != nil { resp["usage"] = usageData } + if modelStats != nil { + resp["model_stats"] = modelStats + } c.JSON(http.StatusOK, resp) return } - // 余额模式:返回钱包余额 + 用量统计 - latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + // 余额模式 + latestUser, err := h.userService.GetByID(ctx, subject.UserID) if err != nil { h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info") return } resp := gin.H{ + "mode": "unrestricted", "isValid": true, "planName": "钱包余额", "remaining": latestUser.Balance, @@ -868,6 +1130,9 @@ func (h *GatewayHandler) Usage(c *gin.Context) { if usageData != nil { resp["usage"] = usageData } + if modelStats != nil { + resp["model_stats"] = modelStats + } c.JSON(http.StatusOK, resp) } @@ -1375,6 +1640,18 @@ func billingErrorDetails(err error) (status int, code, message string) { } return http.StatusServiceUnavailable, "billing_service_error", msg } + if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) { + msg := pkgerrors.Message(err) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg + } + if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) { + msg := pkgerrors.Message(err) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg + } + if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) { + msg := pkgerrors.Message(err) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg + } msg := pkgerrors.Message(err) if msg == "" { logger.L().With( @@ -1431,3 +1708,24 @@ func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { }() task(ctx) } + +// getUserMsgQueueMode 获取当前请求的 UMQ 模式 +// 返回 "serialize" | "throttle" | "" +func (h *GatewayHandler) getUserMsgQueueMode(account *service.Account, parsed *service.ParsedRequest) string { + if h.userMsgQueueHelper == nil { + return "" + } + // 仅适用于 Anthropic OAuth/SetupToken 账号 + if 
!account.IsAnthropicOAuthOrSetupToken() { + return "" + } + if !service.IsRealUserMessage(parsed) { + return "" + } + // 账号级模式优先,fallback 到全局配置 + mode := account.GetUserMsgQueueMode() + if mode == "" { + mode = h.cfg.Gateway.UserMessageQueue.GetEffectiveMode() + } + return mode +} diff --git a/backend/internal/handler/gateway_handler_stream_failover_test.go b/backend/internal/handler/gateway_handler_stream_failover_test.go new file mode 100644 index 00000000..dc4b8dd2 --- /dev/null +++ b/backend/internal/handler/gateway_handler_stream_failover_test.go @@ -0,0 +1,122 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。 +const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" + + "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n" + +// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证: +// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时, +// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。 +// 具体验证: +// 1. c.Writer.Size() 检测条件正确触发(字节数已增加) +// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾 +// 3. 
响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化) +func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size()) + sizeBeforeForward := c.Writer.Size() + require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)") + + // 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start) + _, err := c.Writer.Write([]byte(partialMessageStartSSE)) + require.NoError(t, err) + + // 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward) + require.NotEqual(t, sizeBeforeForward, c.Writer.Size(), + "写入 SSE 内容后 writer size 必须增加,守卫条件应为 true") + + // 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403) + failoverErr := &service.UpstreamFailoverError{ + StatusCode: http.StatusForbidden, + ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`), + } + + // 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true + h := &GatewayHandler{} + h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true) + + body := w.Body.String() + + // 断言 A:响应体中包含最初写入的 message_start SSE 事件行 + require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件") + + // 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n) + require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"), + "响应体应以 JSON 对象结尾(SSE error event 的 data 字段)") + require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件") + + // 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化) + firstIdx := strings.Index(body, "event: message_start") + lastIdx := strings.LastIndex(body, "event: message_start") + assert.Equal(t, firstIdx, lastIdx, + "响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次") +} + +// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 
与上述测试相同, +// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。 +func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil) + + sizeBeforeForward := c.Writer.Size() + + _, err := c.Writer.Write([]byte(partialMessageStartSSE)) + require.NoError(t, err) + + require.NotEqual(t, sizeBeforeForward, c.Writer.Size()) + + failoverErr := &service.UpstreamFailoverError{ + StatusCode: http.StatusForbidden, + } + + h := &GatewayHandler{} + h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true) + + body := w.Body.String() + + require.Contains(t, body, "event: message_start") + require.Contains(t, body, `"type":"error"`) + + firstIdx := strings.Index(body, "event: message_start") + lastIdx := strings.LastIndex(body, "event: message_start") + assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start") +} + +// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景: +// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容, +// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。 +func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // 模拟 writerSizeBeforeForward:初始为 -1 + sizeBeforeForward := c.Writer.Size() + + // Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前) + // c.Writer.Size() 仍为 -1 + + // 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发 + guardTriggered := c.Writer.Size() != sizeBeforeForward + require.False(t, guardTriggered, + "未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续") +} diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go 
b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 2afa6440..6bcc0003 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc return result, nil } func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil } +func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil } func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) { t.Helper() @@ -138,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // accountRepo (not used: scheduler snapshot hit) &fakeGroupRepo{group: group}, nil, // usageLogRepo + nil, // usageBillingRepo nil, // userRepo nil, // userSubRepo nil, // userGroupRateRepo @@ -155,11 +157,12 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // sessionLimitCache nil, // rpmCache nil, // digestStore + nil, // settingService ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 cfg := &config.Config{RunMode: config.RunModeSimple} - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) diff --git a/backend/internal/handler/gateway_helper_fastpath_test.go b/backend/internal/handler/gateway_helper_fastpath_test.go index 31d489f0..c7c0fb6c 100644 --- a/backend/internal/handler/gateway_helper_fastpath_test.go +++ b/backend/internal/handler/gateway_helper_fastpath_test.go @@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a 
return nil } +func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + return nil +} + func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) { cache := &concurrencyCacheMock{ acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go index f8f7eaca..9e904107 100644 --- a/backend/internal/handler/gateway_helper_hotpath_test.go +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont return nil } +func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + return nil +} + func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 50af9c8f..cfe80911 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -503,6 +503,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ Result: result, @@ -510,8 +513,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { User: apiKey.User, Account: account, Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: 
userAgent, IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 ForceCacheBilling: fs.ForceCacheBilling, diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 1e1247fc..89d556cc 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -12,6 +12,7 @@ type AdminHandlers struct { Account *admin.AccountHandler Announcement *admin.AnnouncementHandler DataManagement *admin.DataManagementHandler + Backup *admin.BackupHandler OAuth *admin.OAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler GeminiOAuth *admin.GeminiOAuthHandler @@ -27,6 +28,7 @@ type AdminHandlers struct { UserAttribute *admin.UserAttributeHandler ErrorPassthrough *admin.ErrorPassthroughHandler APIKey *admin.AdminAPIKeyHandler + ScheduledTest *admin.ScheduledTestHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go new file mode 100644 index 00000000..4db5cadd --- /dev/null +++ b/backend/internal/handler/openai_chat_completions.go @@ -0,0 +1,286 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ChatCompletions handles OpenAI Chat Completions API requests. 
+// POST /v1/chat/completions +func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { + streamStarted := false + defer h.recoverResponsesPanic(c, &streamStarted) + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.openai_gateway.chat_completions", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + if !gjson.ValidBytes(body) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() + + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, reqStream, body) + + if h.errorPassthroughService != nil { + 
service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() + + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog) + if !acquired { + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := h.gatewayService.GenerateSessionHash(c, body) + promptCacheKey := h.gatewayService.ExtractSessionID(c, body) + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + sameAccountRetryCount := make(map[int64]int) + var lastFailoverErr *service.UpstreamFailoverError + + for { + c.Set("openai_chat_completions_fallback_model", "") + reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err != nil { + reqLog.Warn("openai_chat_completions.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if len(failedAccountIDs) == 0 { + defaultModel := "" + if apiKey.Group != nil { + defaultModel = apiKey.Group.DefaultMappedModel + } + if defaultModel != "" && defaultModel != reqModel { + 
reqLog.Info("openai_chat_completions.fallback_to_default_model", + zap.String("default_mapped_model", defaultModel), + ) + selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + defaultModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err == nil && selection != nil { + c.Set("openai_chat_completions_fallback_model", defaultModel) + } + } + if err != nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + return + } + } else { + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) + } else { + h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted) + } + return + } + } + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + account := selection.Account + sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account) + reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) + _ = scheduleDecision + setOpsSelectedAccount(c, account.ID, account.Platform) + + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog) + if !acquired { + return + } + + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() + + defaultMappedModel := c.GetString("openai_chat_completions_fallback_model") + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) + + forwardDurationMs := time.Since(forwardStart).Milliseconds() + if accountReleaseFunc != nil { + 
accountReleaseFunc() + } + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // Pool mode: retry on the same account + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai_chat_completions.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + 
reqLog.Warn("openai_chat_completions.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) + return + } + if result != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + }); err != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.chat_completions"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("openai_chat_completions.record_usage_failed", zap.Error(err)) + } + }) + reqLog.Debug("openai_chat_completions.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) + return + } +} diff --git a/backend/internal/handler/openai_gateway_compact_log_test.go b/backend/internal/handler/openai_gateway_compact_log_test.go new file mode 100644 index 00000000..062f318b --- /dev/null +++ b/backend/internal/handler/openai_gateway_compact_log_test.go @@ -0,0 +1,192 @@ +package handler + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +var handlerStructuredLogCaptureMu 
sync.Mutex + +type handlerInMemoryLogSink struct { + mu sync.Mutex + events []*logger.LogEvent +} + +func (s *handlerInMemoryLogSink) WriteLogEvent(event *logger.LogEvent) { + if event == nil { + return + } + cloned := *event + if event.Fields != nil { + cloned.Fields = make(map[string]any, len(event.Fields)) + for k, v := range event.Fields { + cloned.Fields[k] = v + } + } + s.mu.Lock() + s.events = append(s.events, &cloned) + s.mu.Unlock() +} + +func (s *handlerInMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool { + s.mu.Lock() + defer s.mu.Unlock() + wantLevel := strings.ToLower(strings.TrimSpace(level)) + for _, ev := range s.events { + if ev == nil { + continue + } + if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel { + return true + } + } + return false +} + +func (s *handlerInMemoryLogSink) ContainsFieldValue(field, substr string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, ev := range s.events { + if ev == nil || ev.Fields == nil { + continue + } + if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) { + return true + } + } + return false +} + +func captureHandlerStructuredLog(t *testing.T) (*handlerInMemoryLogSink, func()) { + t.Helper() + handlerStructuredLogCaptureMu.Lock() + + err := logger.Init(logger.InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: logger.SamplingOptions{Enabled: false}, + }) + require.NoError(t, err) + + sink := &handlerInMemoryLogSink{} + logger.SetSink(sink) + return sink, func() { + logger.SetSink(nil) + handlerStructuredLogCaptureMu.Unlock() + } +} + +func TestIsOpenAIRemoteCompactPath(t *testing.T) { + require.False(t, isOpenAIRemoteCompactPath(nil)) + + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + c.Request = httptest.NewRequest(http.MethodPost, 
"/v1/responses/compact", nil) + require.True(t, isOpenAIRemoteCompactPath(c)) + + c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact/", nil) + require.True(t, isOpenAIRemoteCompactPath(c)) + + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + require.False(t, isOpenAIRemoteCompactPath(c)) +} + +func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureHandlerStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0") + c.Set(opsModelKey, "gpt-5.3-codex") + c.Set(opsAccountIDKey, int64(123)) + c.Header("x-request-id", "rid-compact-ok") + c.Status(http.StatusOK) + + h := &OpenAIGatewayHandler{} + h.logOpenAIRemoteCompactOutcome(c, time.Now().Add(-8*time.Millisecond)) + + require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info")) + require.True(t, logSink.ContainsFieldValue("compact_outcome", "succeeded")) + require.True(t, logSink.ContainsFieldValue("status_code", "200")) + require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact")) + require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.3-codex")) + require.True(t, logSink.ContainsFieldValue("account_id", "123")) + require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-compact-ok")) +} + +func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureHandlerStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0") + c.Status(http.StatusBadGateway) + + h := &OpenAIGatewayHandler{} + 
h.logOpenAIRemoteCompactOutcome(c, time.Now()) + + require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn")) + require.True(t, logSink.ContainsFieldValue("compact_outcome", "failed")) + require.True(t, logSink.ContainsFieldValue("status_code", "502")) + require.True(t, logSink.ContainsFieldValue("path", "/responses/compact")) +} + +func TestLogOpenAIRemoteCompactOutcome_NonCompactSkips(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureHandlerStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.Status(http.StatusOK) + + h := &OpenAIGatewayHandler{} + h.logOpenAIRemoteCompactOutcome(c, time.Now()) + + require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info")) + require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn")) +} + +func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureHandlerStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`)) + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0") + + h := &OpenAIGatewayHandler{} + h.Responses(c) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn")) + require.True(t, logSink.ContainsFieldValue("status_code", "401")) + require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact")) +} diff --git a/backend/internal/handler/openai_gateway_endpoint_normalization_test.go b/backend/internal/handler/openai_gateway_endpoint_normalization_test.go new file mode 100644 index 
00000000..0dacd74d --- /dev/null +++ b/backend/internal/handler/openai_gateway_endpoint_normalization_test.go @@ -0,0 +1,56 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the +// unified GetUpstreamEndpoint helper produces the same results as the +// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests. +func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + path string + want string + }{ + { + name: "responses root maps to responses upstream", + path: "/v1/responses", + want: EndpointResponses, + }, + { + name: "responses compact keeps compact suffix", + path: "/openai/v1/responses/compact", + want: "/v1/responses/compact", + }, + { + name: "responses nested suffix preserved", + path: "/openai/v1/responses/compact/detail", + want: "/v1/responses/compact/detail", + }, + { + name: "non responses path uses platform fallback", + path: "/v1/messages", + want: EndpointResponses, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil) + + got := GetUpstreamEndpoint(c, service.PlatformOpenAI) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 4bbd17ba..c681e61d 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -20,6 +20,7 @@ import ( coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" + "github.com/google/uuid" "github.com/tidwall/gjson" "go.uber.org/zap" ) @@ -33,6 +34,7 @@ type OpenAIGatewayHandler 
struct { errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int + cfg *config.Config } // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler @@ -61,6 +63,7 @@ func NewOpenAIGatewayHandler( errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), maxAccountSwitches: maxAccountSwitches, + cfg: cfg, } } @@ -70,6 +73,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。 streamStarted := false defer h.recoverResponsesPanic(c, &streamStarted) + compactStartedAt := time.Now() + defer h.logOpenAIRemoteCompactOutcome(c, compactStartedAt) setOpenAIClientTransportHTTP(c) requestStart := time.Now() @@ -114,6 +119,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } setOpsRequestContext(c, "", false, body) + sessionHashBody := body + if service.IsOpenAIResponsesCompactPathForTest(c) { + if compactSeed := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()); compactSeed != "" { + c.Set(service.OpenAICompactSessionSeedKeyForTest(), compactSeed) + } + normalizedCompactBody, normalizedCompact, compactErr := service.NormalizeOpenAICompactRequestBodyForTest(body) + if compactErr != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to normalize compact request body") + return + } + if normalizedCompact { + body = normalizedCompactBody + } + } // 校验请求体 JSON 合法性 if !gjson.ValidBytes(body) { @@ -189,11 +208,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } // Generate session hash (header first; fallback to prompt_cache_key) - sessionHash := h.gatewayService.GenerateSessionHash(c, body) + sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody) maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) + sameAccountRetryCount := 
make(map[int64]int) var lastFailoverErr *service.UpstreamFailoverError for { @@ -241,6 +261,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { zap.Float64("load_skew", scheduleDecision.LoadSkew), ) account := selection.Account + sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account) reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) setOpsSelectedAccount(c, account.ID, account.Platform) @@ -270,6 +291,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // 池模式:同账号重试 + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } h.gatewayService.RecordOpenAIAccountSwitch() failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr @@ -301,6 +341,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } if result != nil { + if account.Type == service.AccountTypeOAuth { + h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders) + } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) } else { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) @@ -309,18 +352,22 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := 
c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.responses"), @@ -340,6 +387,431 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } } +func isOpenAIRemoteCompactPath(c *gin.Context) bool { + if c == nil || c.Request == nil || c.Request.URL == nil { + return false + } + normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/") + return strings.HasSuffix(normalizedPath, "/responses/compact") +} + +func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, startedAt time.Time) { + if !isOpenAIRemoteCompactPath(c) { + return + } + + var ( + ctx = context.Background() + path string + status int + ) + if c != nil { + if c.Request != nil { + ctx = c.Request.Context() + if c.Request.URL != nil { + path = strings.TrimSpace(c.Request.URL.Path) + } + } + if c.Writer != nil { + status = c.Writer.Status() + } + } + + outcome := "failed" + if status >= 200 && status < 300 { + outcome = "succeeded" + } + latencyMs := time.Since(startedAt).Milliseconds() + if latencyMs < 0 { + latencyMs = 0 + } + + fields := []zap.Field{ + zap.String("component", "handler.openai_gateway.responses"), + 
zap.Bool("remote_compact", true), + zap.String("compact_outcome", outcome), + zap.Int("status_code", status), + zap.Int64("latency_ms", latencyMs), + zap.String("path", path), + zap.Bool("force_codex_cli", h != nil && h.cfg != nil && h.cfg.Gateway.ForceCodexCLI), + } + + if c != nil { + if userAgent := strings.TrimSpace(c.GetHeader("User-Agent")); userAgent != "" { + fields = append(fields, zap.String("request_user_agent", userAgent)) + } + if v, ok := c.Get(opsModelKey); ok { + if model, ok := v.(string); ok && strings.TrimSpace(model) != "" { + fields = append(fields, zap.String("request_model", strings.TrimSpace(model))) + } + } + if v, ok := c.Get(opsAccountIDKey); ok { + if accountID, ok := v.(int64); ok && accountID > 0 { + fields = append(fields, zap.Int64("account_id", accountID)) + } + } + if c.Writer != nil { + if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("x-request-id")); upstreamRequestID != "" { + fields = append(fields, zap.String("upstream_request_id", upstreamRequestID)) + } else if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("X-Request-Id")); upstreamRequestID != "" { + fields = append(fields, zap.String("upstream_request_id", upstreamRequestID)) + } + } + } + + log := logger.FromContext(ctx).With(fields...) + if outcome == "succeeded" { + log.Info("codex.remote_compact.succeeded") + return + } + log.Warn("codex.remote_compact.failed") +} + +// Messages handles Anthropic Messages API requests routed to OpenAI platform. 
+// POST /v1/messages (when group platform is OpenAI) +func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { + streamStarted := false + defer h.recoverAnthropicMessagesPanic(c, &streamStarted) + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.anthropicErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.openai_gateway.messages", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + // 检查分组是否允许 /v1/messages 调度 + if apiKey.Group != nil && !apiKey.Group.AllowMessagesDispatch { + h.anthropicErrorResponse(c, http.StatusForbidden, "permission_error", + "This group does not allow /v1/messages dispatch") + return + } + + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.anthropicErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + if !gjson.ValidBytes(body) { + h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", 
"model is required") + return + } + reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() + + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, reqStream, body) + + // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() + + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog) + if !acquired { + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.anthropicStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := h.gatewayService.GenerateSessionHash(c, body) + promptCacheKey := h.gatewayService.ExtractSessionID(c, body) + + // Anthropic 格式的请求在 metadata.user_id 中携带 session 标识, + // 而非 OpenAI 的 session_id/conversation_id headers。 + // 从中派生 sessionHash(sticky session)和 promptCacheKey(upstream cache)。 + if sessionHash == "" || promptCacheKey == "" { + if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" { + seed := reqModel + "-" + userID + if promptCacheKey == "" { + promptCacheKey = service.GenerateSessionUUID(seed) + } + if sessionHash == "" { + sessionHash = service.DeriveSessionHashFromSeed(seed) + } + } + } + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + 
sameAccountRetryCount := make(map[int64]int) + var lastFailoverErr *service.UpstreamFailoverError + + for { + // 清除上一次迭代的降级模型标记,避免残留影响本次迭代 + c.Set("openai_messages_fallback_model", "") + reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", // no previous_response_id + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err != nil { + reqLog.Warn("openai_messages.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + // 首次调度失败 + 有默认映射模型 → 用默认模型重试 + if len(failedAccountIDs) == 0 { + defaultModel := "" + if apiKey.Group != nil { + defaultModel = apiKey.Group.DefaultMappedModel + } + if defaultModel != "" && defaultModel != reqModel { + reqLog.Info("openai_messages.fallback_to_default_model", + zap.String("default_mapped_model", defaultModel), + ) + selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + defaultModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err == nil && selection != nil { + c.Set("openai_messages_fallback_model", defaultModel) + } + } + if err != nil { + h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + return + } + } else { + if lastFailoverErr != nil { + h.handleAnthropicFailoverExhausted(c, lastFailoverErr, streamStarted) + } else { + h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted) + } + return + } + } + if selection == nil || selection.Account == nil { + h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + account := selection.Account + 
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account) + reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) + _ = scheduleDecision + setOpsSelectedAccount(c, account.ID, account.Platform) + + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog) + if !acquired { + return + } + + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() + + // 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时, + // 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。 + defaultMappedModel := c.GetString("openai_messages_fallback_model") + result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) + + forwardDurationMs := time.Since(forwardStart).Milliseconds() + if accountReleaseFunc != nil { + accountReleaseFunc() + } + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // 池模式:同账号重试 + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai_messages.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", 
failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai_messages.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted) + reqLog.Warn("openai_messages.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) + return + } + if result != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, + }); err != nil { + logger.L().With( + zap.String("component", 
"handler.openai_gateway.messages"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("openai_messages.record_usage_failed", zap.Error(err)) + } + }) + reqLog.Debug("openai_messages.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) + return + } +} + +// anthropicErrorResponse writes an error in Anthropic Messages API format. +func (h *OpenAIGatewayHandler) anthropicErrorResponse(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +// anthropicStreamingAwareError handles errors that may occur during streaming, +// using Anthropic SSE error format. +func (h *OpenAIGatewayHandler) anthropicStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { + if streamStarted { + flusher, ok := c.Writer.(http.Flusher) + if ok { + errPayload, _ := json.Marshal(gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": message, + }, + }) + fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errPayload) //nolint:errcheck + flusher.Flush() + } + return + } + h.anthropicErrorResponse(c, status, errType, message) +} + +// handleAnthropicFailoverExhausted maps upstream failover errors to Anthropic format. +func (h *OpenAIGatewayHandler) handleAnthropicFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(failoverErr.StatusCode) + h.anthropicStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +// ensureAnthropicErrorResponse writes a fallback Anthropic error if no response was written. 
+func (h *OpenAIGatewayHandler) ensureAnthropicErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted) + return true +} + func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool { if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() { return true @@ -756,17 +1228,23 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { if turnErr != nil || result == nil { return } + if account.Type == service.AccountTypeOAuth { + h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders) + } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) h.submitUsageRecordTask(func(taskCtx context.Context) { if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), + APIKeyService: h.apiKeyService, }); err != nil { reqLog.Error("openai.websocket_record_usage_failed", zap.Int64("account_id", account.ID), @@ -817,6 +1295,26 @@ func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStart ) } +// recoverAnthropicMessagesPanic recovers from panics in the Anthropic Messages +// handler and returns an Anthropic-formatted error response. 
+func (h *OpenAIGatewayHandler) recoverAnthropicMessagesPanic(c *gin.Context, streamStarted *bool) { + recovered := recover() + if recovered == nil { + return + } + + started := streamStarted != nil && *streamStarted + requestLogger(c, "handler.openai_gateway.messages").Error( + "openai.messages_panic_recovered", + zap.Bool("stream_started", started), + zap.Any("panic", recovered), + zap.ByteString("stack", debug.Stack()), + ) + if !started { + h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "Internal server error") + } +} + func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool { missing := h.missingResponsesDependencies() if len(missing) == 0 { @@ -1022,6 +1520,14 @@ func setOpenAIClientTransportWS(c *gin.Context) { service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS) } +func ensureOpenAIPoolModeSessionHash(sessionHash string, account *service.Account) string { + if sessionHash != "" || account == nil || !account.IsPoolMode() { + return sessionHash + } + // 为当前请求生成一次性粘性会话键,确保同账号重试不会重新负载均衡到其他账号。 + return "openai-pool-retry-" + uuid.NewString() +} + func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string { gid := int64(0) if groupID != nil { diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 2f53d655..ceb06f0e 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -26,11 +26,28 @@ const ( opsStreamKey = "ops_stream" opsRequestBodyKey = "ops_request_body" opsAccountIDKey = "ops_account_id" + + // 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用 + opsErrContextCanceled = "context canceled" + opsErrNoAvailableAccounts = "no available accounts" + opsErrInvalidAPIKey = "invalid_api_key" + opsErrAPIKeyRequired = "api_key_required" + opsErrInsufficientBalance = "insufficient balance" + opsErrInsufficientAccountBalance = "insufficient account balance" 
+ opsErrInsufficientQuota = "insufficient_quota" + + // 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited) + opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE" + opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED" + opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND" + opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID" + opsCodeUserInactive = "USER_INACTIVE" ) const ( opsErrorLogTimeout = 5 * time.Second opsErrorLogDrainTimeout = 10 * time.Second + opsErrorLogBatchWindow = 200 * time.Millisecond opsErrorLogMinWorkerCount = 4 opsErrorLogMaxWorkerCount = 32 @@ -38,6 +55,7 @@ const ( opsErrorLogQueueSizePerWorker = 128 opsErrorLogMinQueueSize = 256 opsErrorLogMaxQueueSize = 8192 + opsErrorLogBatchSize = 32 ) type opsErrorLogJob struct { @@ -82,27 +100,82 @@ func startOpsErrorLogWorkers() { for i := 0; i < workerCount; i++ { go func() { defer opsErrorLogWorkersWg.Done() - for job := range opsErrorLogQueue { - opsErrorLogQueueLen.Add(-1) - if job.ops == nil || job.entry == nil { - continue + for { + job, ok := <-opsErrorLogQueue + if !ok { + return } - func() { - defer func() { - if r := recover(); r != nil { - log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack()) + opsErrorLogQueueLen.Add(-1) + batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize) + batch = append(batch, job) + + timer := time.NewTimer(opsErrorLogBatchWindow) + batchLoop: + for len(batch) < opsErrorLogBatchSize { + select { + case nextJob, ok := <-opsErrorLogQueue: + if !ok { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + flushOpsErrorLogBatch(batch) + return } - }() - ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) - _ = job.ops.RecordError(ctx, job.entry, nil) - cancel() - opsErrorLogProcessed.Add(1) - }() + opsErrorLogQueueLen.Add(-1) + batch = append(batch, nextJob) + case <-timer.C: + break batchLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + 
flushOpsErrorLogBatch(batch) } }() } } +func flushOpsErrorLogBatch(batch []opsErrorLogJob) { + if len(batch) == 0 { + return + } + defer func() { + if r := recover(); r != nil { + log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack()) + } + }() + + grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch)) + var processed int64 + for _, job := range batch { + if job.ops == nil || job.entry == nil { + continue + } + grouped[job.ops] = append(grouped[job.ops], job.entry) + processed++ + } + if processed == 0 { + return + } + + for opsSvc, entries := range grouped { + if opsSvc == nil || len(entries) == 0 { + continue + } + ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) + _ = opsSvc.RecordErrorBatch(ctx, entries) + cancel() + } + opsErrorLogProcessed.Add(processed) +} + func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) { if ops == nil || entry == nil { return @@ -967,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string { return errType } switch strings.TrimSpace(code) { - case "INSUFFICIENT_BALANCE": + case opsCodeInsufficientBalance: return "billing_error" - case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID": + case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid: return "subscription_error" default: return "api_error" @@ -981,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string { // Standardized phases: request|auth|routing|upstream|network|internal // Map billing/concurrency/response => request; scheduling => routing. 
switch strings.TrimSpace(code) { - case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID": + case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid: return "request" } @@ -1000,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string { case "upstream_error", "overloaded_error": return "upstream" case "api_error": - if strings.Contains(msg, "no available accounts") { + if strings.Contains(msg, opsErrNoAvailableAccounts) { return "routing" } return "internal" @@ -1046,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool { func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool { switch strings.TrimSpace(code) { - case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE": + case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive: return true } if phase == "billing" || phase == "concurrency" { @@ -1140,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message // Check if context canceled errors should be ignored (client disconnects) if settings.IgnoreContextCanceled { - if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") { + if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) { return true } } // Check if "no available accounts" errors should be ignored if settings.IgnoreNoAvailableAccounts { - if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") { + if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) { return true } } // Check if invalid/missing API key errors should be ignored (user 
misconfiguration) if settings.IgnoreInvalidApiKeyErrors { - if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") { + if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) { + return true + } + } + + // Check if insufficient balance errors should be ignored + if settings.IgnoreInsufficientBalanceErrors { + if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) || + strings.Contains(bodyLower, opsErrInsufficientQuota) || + strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) { return true } } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 2141a9ee..92061895 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -32,26 +32,29 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { } response.Success(c, dto.PublicSettings{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - PromoCodeEnabled: settings.PromoCodeEnabled, - PasswordResetEnabled: settings.PasswordResetEnabled, - InvitationCodeEnabled: settings.InvitationCodeEnabled, - TotpEnabled: settings.TotpEnabled, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - HomeContent: settings.HomeContent, - HideCcsImportButton: settings.HideCcsImportButton, - PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, - PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, - SoraClientEnabled: 
settings.SoraClientEnabled, - Version: h.version, + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: settings.PromoCodeEnabled, + PasswordResetEnabled: settings.PasswordResetEnabled, + InvitationCodeEnabled: settings.InvitationCodeEnabled, + TotpEnabled: settings.TotpEnabled, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, + PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, + PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + SoraClientEnabled: settings.SoraClientEnabled, + BackendModeEnabled: settings.BackendModeEnabled, + Version: h.version, }) } diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index 5df7fa0a..dab17673 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -996,7 +996,7 @@ func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*se } func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil } func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil } -func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { +func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, 
_ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { return nil, nil, nil } func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { @@ -1032,6 +1032,15 @@ func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64 func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error { return nil } +func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error { + return nil +} +func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error { + return nil +} +func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) { + return nil, nil +} // newTestAPIKeyService 创建测试用的 APIKeyService func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService { @@ -2089,6 +2098,12 @@ func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) { return r.accounts, nil } +func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) { + return r.accounts, nil +} func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error { return nil } @@ -2117,6 +2132,14 @@ func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service return 0, nil } +func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error { + return nil +} + +func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error { + return nil +} + // ==================== Stub: 
SoraClient (用于 SoraGatewayService) ==================== var _ service.SoraClient = (*stubSoraClientForHandler)(nil) @@ -2183,8 +2206,8 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac // newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。 func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { return service.NewGatewayService( - accountRepo, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ) } diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 48c1e451..dc301ce1 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -399,17 +399,23 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, }); err != nil { logger.L().With( zap.String("component", "handler.sora_gateway.chat_completions"), diff --git 
a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 355cdb7a..7170415d 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -182,6 +182,12 @@ func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platfo func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { return r.ListSchedulableByPlatforms(ctx, platforms) } +func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} +func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + return r.ListSchedulableByPlatforms(ctx, platforms) +} func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } @@ -210,6 +216,14 @@ func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s return 0, nil } +func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + func (r *stubAccountRepo) listSchedulable() []service.Account { var result []service.Account for _, acc := range r.accounts { @@ -320,6 +334,14 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { return nil, nil } + +func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, 
accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) { + return []usagestats.EndpointStat{}, nil +} + +func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) { + return []usagestats.EndpointStat{}, nil +} func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { return nil, nil } @@ -329,6 +351,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) { return nil, nil } +func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) { + return nil, nil +} func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { return nil, nil } @@ -405,7 +430,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { deferredService := service.NewDeferredService(accountRepo, nil, 0) billingService := service.NewBillingService(cfg, nil) concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{}) - billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg) + billingCacheService := service.NewBillingCacheService(nil, nil, nil, nil, cfg) t.Cleanup(func() { billingCacheService.Stop() }) @@ -417,6 +442,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, nil, nil, + nil, 
testutil.StubGatewayCache{}, cfg, nil, @@ -431,6 +457,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { testutil.StubSessionLimitCache{}, nil, // rpmCache nil, // digestStore + nil, // settingService ) soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index 2bd0e0d7..483f5105 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -114,8 +114,8 @@ func (h *UsageHandler) List(c *gin.Context) { response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") return } - // Set end time to end of day - t = t.Add(24*time.Hour - time.Nanosecond) + // Use half-open range [start, end), move to next calendar day start (DST-safe). + t = t.AddDate(0, 0, 1) endTime = &t } @@ -227,8 +227,8 @@ func (h *UsageHandler) Stats(c *gin.Context) { response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") return } - // 设置结束时间为当天结束 - endTime = endTime.Add(24*time.Hour - time.Nanosecond) + // 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。 + endTime = endTime.AddDate(0, 0, 1) } else { // 使用 period 参数 period := c.DefaultQuery("period", "today") diff --git a/backend/internal/handler/user_msg_queue_helper.go b/backend/internal/handler/user_msg_queue_helper.go new file mode 100644 index 00000000..50449b13 --- /dev/null +++ b/backend/internal/handler/user_msg_queue_helper.go @@ -0,0 +1,237 @@ +package handler + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// UserMsgQueueHelper 用户消息串行队列 Handler 层辅助 +// 复用 ConcurrencyHelper 的退避 + SSE ping 模式 +type UserMsgQueueHelper struct { + queueService *service.UserMessageQueueService + pingFormat SSEPingFormat + pingInterval time.Duration +} + +// NewUserMsgQueueHelper 创建用户消息串行队列辅助 +func NewUserMsgQueueHelper( + queueService 
*service.UserMessageQueueService, + pingFormat SSEPingFormat, + pingInterval time.Duration, +) *UserMsgQueueHelper { + if pingInterval <= 0 { + pingInterval = defaultPingInterval + } + return &UserMsgQueueHelper{ + queueService: queueService, + pingFormat: pingFormat, + pingInterval: pingInterval, + } +} + +// AcquireWithWait 等待获取串行锁,流式请求期间发送 SSE ping +// 返回的 releaseFunc 内部使用 sync.Once,确保只执行一次释放 +func (h *UserMsgQueueHelper) AcquireWithWait( + c *gin.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + timeout time.Duration, + reqLog *zap.Logger, +) (releaseFunc func(), err error) { + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + // 先尝试立即获取 + result, err := h.queueService.TryAcquire(ctx, accountID) + if err != nil { + return nil, err // fail-open 已在 service 层处理 + } + + if result.Acquired { + // 获取成功,执行 RPM 自适应延迟 + if err := h.queueService.EnforceDelay(ctx, accountID, baseRPM); err != nil { + if ctx.Err() != nil { + // 延迟期间 context 取消,释放锁 + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = h.queueService.Release(bgCtx, accountID, result.RequestID) + bgCancel() + return nil, ctx.Err() + } + } + reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID)) + return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil + } + + // 需要等待:指数退避轮询 + return h.waitForLockWithPing(c, ctx, accountID, baseRPM, isStream, streamStarted, reqLog) +} + +// waitForLockWithPing 等待获取锁,流式请求期间发送 SSE ping +func (h *UserMsgQueueHelper) waitForLockWithPing( + c *gin.Context, + ctx context.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), error) { + needPing := isStream && h.pingFormat != "" + + var flusher http.Flusher + if needPing { + var ok bool + flusher, ok = c.Writer.(http.Flusher) + if !ok { + needPing = false + } + } + + var pingCh <-chan time.Time + if needPing { + pingTicker := 
time.NewTicker(h.pingInterval) + defer pingTicker.Stop() + pingCh = pingTicker.C + } + + backoff := initialBackoff + timer := time.NewTimer(backoff) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("umq wait timeout for account %d", accountID) + + case <-pingCh: + if !*streamStarted { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + *streamStarted = true + } + if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil { + return nil, err + } + flusher.Flush() + + case <-timer.C: + result, err := h.queueService.TryAcquire(ctx, accountID) + if err != nil { + return nil, err + } + if result.Acquired { + // 获取成功,执行 RPM 自适应延迟 + if delayErr := h.queueService.EnforceDelay(ctx, accountID, baseRPM); delayErr != nil { + if ctx.Err() != nil { + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = h.queueService.Release(bgCtx, accountID, result.RequestID) + bgCancel() + return nil, ctx.Err() + } + } + reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID)) + return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil + } + backoff = nextBackoff(backoff) + timer.Reset(backoff) + } + } +} + +// makeReleaseFunc 创建锁释放函数(使用 sync.Once 确保只执行一次) +func (h *UserMsgQueueHelper) makeReleaseFunc(accountID int64, requestID string, reqLog *zap.Logger) func() { + var once sync.Once + return func() { + once.Do(func() { + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer bgCancel() + if err := h.queueService.Release(bgCtx, accountID, requestID); err != nil { + reqLog.Warn("gateway.umq_release_failed", + zap.Int64("account_id", accountID), + zap.Error(err), + ) + } else { + reqLog.Debug("gateway.umq_lock_released", zap.Int64("account_id", accountID)) + } + }) + } +} + +// ThrottleWithPing 软性限速模式:施加 RPM 自适应延迟,流式期间发送 SSE ping +// 
不获取串行锁,不阻塞并发。返回后即可转发请求。 +func (h *UserMsgQueueHelper) ThrottleWithPing( + c *gin.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + timeout time.Duration, + reqLog *zap.Logger, +) error { + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + delay := h.queueService.CalculateRPMAwareDelay(ctx, accountID, baseRPM) + if delay <= 0 { + return nil + } + + reqLog.Debug("gateway.umq_throttle_delay", + zap.Int64("account_id", accountID), + zap.Duration("delay", delay), + ) + + // 延迟期间发送 SSE ping(复用 waitForLockWithPing 的 ping 逻辑) + needPing := isStream && h.pingFormat != "" + var flusher http.Flusher + if needPing { + flusher, _ = c.Writer.(http.Flusher) + if flusher == nil { + needPing = false + } + } + + var pingCh <-chan time.Time + if needPing { + pingTicker := time.NewTicker(h.pingInterval) + defer pingTicker.Stop() + pingCh = pingTicker.C + } + + timer := time.NewTimer(delay) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-pingCh: + // SSE ping 逻辑(与 waitForLockWithPing 一致) + if !*streamStarted { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + *streamStarted = true + } + if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil { + return err + } + flusher.Flush() + case <-timer.C: + return nil + } + } +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 76f5a979..f3aadcf3 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -15,6 +15,7 @@ func ProvideAdminHandlers( accountHandler *admin.AccountHandler, announcementHandler *admin.AnnouncementHandler, dataManagementHandler *admin.DataManagementHandler, + backupHandler *admin.BackupHandler, oauthHandler *admin.OAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler, geminiOAuthHandler *admin.GeminiOAuthHandler, 
@@ -30,6 +31,7 @@ func ProvideAdminHandlers( userAttributeHandler *admin.UserAttributeHandler, errorPassthroughHandler *admin.ErrorPassthroughHandler, apiKeyHandler *admin.AdminAPIKeyHandler, + scheduledTestHandler *admin.ScheduledTestHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -38,6 +40,7 @@ func ProvideAdminHandlers( Account: accountHandler, Announcement: announcementHandler, DataManagement: dataManagementHandler, + Backup: backupHandler, OAuth: oauthHandler, OpenAIOAuth: openaiOAuthHandler, GeminiOAuth: geminiOAuthHandler, @@ -53,6 +56,7 @@ func ProvideAdminHandlers( UserAttribute: userAttributeHandler, ErrorPassthrough: errorPassthroughHandler, APIKey: apiKeyHandler, + ScheduledTest: scheduledTestHandler, } } @@ -126,6 +130,7 @@ var ProviderSet = wire.NewSet( admin.NewAccountHandler, admin.NewAnnouncementHandler, admin.NewDataManagementHandler, + admin.NewBackupHandler, admin.NewOAuthHandler, admin.NewOpenAIOAuthHandler, admin.NewGeminiOAuthHandler, @@ -141,6 +146,7 @@ var ProviderSet = wire.NewSet( admin.NewUserAttributeHandler, admin.NewErrorPassthroughHandler, admin.NewAdminAPIKeyHandler, + admin.NewScheduledTestHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 7cc68060..8ea87f18 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -159,6 +159,8 @@ var claudeModels = []modelDef{ // Antigravity 支持的 Gemini 模型 var geminiModels = []modelDef{ {ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"}, + {ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"}, + {ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"}, {ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 
Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"}, {ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"}, {ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"}, diff --git a/backend/internal/pkg/antigravity/claude_types_test.go b/backend/internal/pkg/antigravity/claude_types_test.go index f7cb0a24..9fc09b1b 100644 --- a/backend/internal/pkg/antigravity/claude_types_test.go +++ b/backend/internal/pkg/antigravity/claude_types_test.go @@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) { requiredIDs := []string{ "claude-opus-4-6-thinking", + "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview", "gemini-3.1-flash-image", "gemini-3.1-flash-image-preview", "gemini-3-pro-image", // legacy compatibility diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 1998221a..af3a0bfc 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -14,8 +14,21 @@ import ( "net/url" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" ) +// ForbiddenError 表示上游返回 403 Forbidden +type ForbiddenError struct { + StatusCode int + Body string +} + +func (e *ForbiddenError) Error() string { + return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body) +} + // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { // 构建 URL,流式请求添加 ?alt=sse 参数 @@ -111,10 +124,68 @@ type IneligibleTier struct { type LoadCodeAssistResponse struct { CloudAICompanionProject string `json:"cloudaicompanionProject"` CurrentTier *TierInfo `json:"currentTier,omitempty"` - PaidTier *TierInfo `json:"paidTier,omitempty"` + PaidTier *PaidTierInfo 
`json:"paidTier,omitempty"` IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"` } +// PaidTierInfo 付费等级信息,包含 AI Credits 余额。 +type PaidTierInfo struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + AvailableCredits []AvailableCredit `json:"availableCredits,omitempty"` +} + +// UnmarshalJSON 兼容 paidTier 既可能是字符串也可能是对象的情况。 +func (p *PaidTierInfo) UnmarshalJSON(data []byte) error { + data = bytes.TrimSpace(data) + if len(data) == 0 || string(data) == "null" { + return nil + } + if data[0] == '"' { + var id string + if err := json.Unmarshal(data, &id); err != nil { + return err + } + p.ID = id + return nil + } + type alias PaidTierInfo + var raw alias + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + *p = PaidTierInfo(raw) + return nil +} + +// AvailableCredit 表示一条 AI Credits 余额记录。 +type AvailableCredit struct { + CreditType string `json:"creditType,omitempty"` + CreditAmount string `json:"creditAmount,omitempty"` + MinimumCreditAmountForUsage string `json:"minimumCreditAmountForUsage,omitempty"` +} + +// GetAmount 将 creditAmount 解析为浮点数。 +func (c *AvailableCredit) GetAmount() float64 { + if c.CreditAmount == "" { + return 0 + } + var value float64 + _, _ = fmt.Sscanf(c.CreditAmount, "%f", &value) + return value +} + +// GetMinimumAmount 将 minimumCreditAmountForUsage 解析为浮点数。 +func (c *AvailableCredit) GetMinimumAmount() float64 { + if c.MinimumCreditAmountForUsage == "" { + return 0 + } + var value float64 + _, _ = fmt.Sscanf(c.MinimumCreditAmountForUsage, "%f", &value) + return value +} + // OnboardUserRequest onboardUser 请求 type OnboardUserRequest struct { TierID string `json:"tierId"` @@ -144,27 +215,39 @@ func (r *LoadCodeAssistResponse) GetTier() string { return "" } +// GetAvailableCredits 返回 paid tier 中的 AI Credits 余额列表。 +func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit { + if r.PaidTier == nil { + return nil + } + return r.PaidTier.AvailableCredits +} 
+ // Client Antigravity API 客户端 type Client struct { httpClient *http.Client } -func NewClient(proxyURL string) *Client { +func NewClient(proxyURL string) (*Client, error) { client := &http.Client{ Timeout: 30 * time.Second, } - if strings.TrimSpace(proxyURL) != "" { - if proxyURLParsed, err := url.Parse(proxyURL); err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURLParsed), - } + _, parsed, err := proxyurl.Parse(proxyURL) + if err != nil { + return nil, err + } + if parsed != nil { + transport := &http.Transport{} + if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { + return nil, fmt.Errorf("configure proxy: %w", err) } + client.Transport = transport } return &Client{ httpClient: client, - } + }, nil } // isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) @@ -507,7 +590,20 @@ type ModelQuotaInfo struct { // ModelInfo 模型信息 type ModelInfo struct { - QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"` + QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"` + DisplayName string `json:"displayName,omitempty"` + SupportsImages *bool `json:"supportsImages,omitempty"` + SupportsThinking *bool `json:"supportsThinking,omitempty"` + ThinkingBudget *int `json:"thinkingBudget,omitempty"` + Recommended *bool `json:"recommended,omitempty"` + MaxTokens *int `json:"maxTokens,omitempty"` + MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` + SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"` +} + +// DeprecatedModelInfo 废弃模型转发信息 +type DeprecatedModelInfo struct { + NewModelID string `json:"newModelId"` } // FetchAvailableModelsRequest fetchAvailableModels 请求 @@ -517,7 +613,8 @@ type FetchAvailableModelsRequest struct { // FetchAvailableModelsResponse fetchAvailableModels 响应 type FetchAvailableModelsResponse struct { - Models map[string]ModelInfo `json:"models"` + Models map[string]ModelInfo `json:"models"` + DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"` } // 
FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON @@ -566,6 +663,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI continue } + if resp.StatusCode == http.StatusForbidden { + return nil, nil, &ForbiddenError{ + StatusCode: resp.StatusCode, + Body: string(respBodyBytes), + } + } + if resp.StatusCode != http.StatusOK { return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) } diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go index 394b6128..7d5bba93 100644 --- a/backend/internal/pkg/antigravity/client_test.go +++ b/backend/internal/pkg/antigravity/client_test.go @@ -190,7 +190,7 @@ func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) { func TestGetTier_PaidTier优先(t *testing.T) { resp := &LoadCodeAssistResponse{ CurrentTier: &TierInfo{ID: "free-tier"}, - PaidTier: &TierInfo{ID: "g1-pro-tier"}, + PaidTier: &PaidTierInfo{ID: "g1-pro-tier"}, } if got := resp.GetTier(); got != "g1-pro-tier" { t.Errorf("应返回 paidTier: got %s", got) @@ -209,7 +209,7 @@ func TestGetTier_回退到CurrentTier(t *testing.T) { func TestGetTier_PaidTier为空ID(t *testing.T) { resp := &LoadCodeAssistResponse{ CurrentTier: &TierInfo{ID: "free-tier"}, - PaidTier: &TierInfo{ID: ""}, + PaidTier: &PaidTierInfo{ID: ""}, } // paidTier.ID 为空时应回退到 currentTier if got := resp.GetTier(); got != "free-tier" { @@ -217,6 +217,32 @@ func TestGetTier_PaidTier为空ID(t *testing.T) { } } +func TestGetAvailableCredits(t *testing.T) { + resp := &LoadCodeAssistResponse{ + PaidTier: &PaidTierInfo{ + ID: "g1-pro-tier", + AvailableCredits: []AvailableCredit{ + { + CreditType: "GOOGLE_ONE_AI", + CreditAmount: "25", + MinimumCreditAmountForUsage: "5", + }, + }, + }, + } + + credits := resp.GetAvailableCredits() + if len(credits) != 1 { + t.Fatalf("AI Credits 数量不匹配: got %d", len(credits)) + } + if credits[0].GetAmount() != 25 { + t.Errorf("CreditAmount 解析不正确: got %v", 
credits[0].GetAmount()) + } + if credits[0].GetMinimumAmount() != 5 { + t.Errorf("MinimumCreditAmountForUsage 解析不正确: got %v", credits[0].GetMinimumAmount()) + } +} + func TestGetTier_两者都为nil(t *testing.T) { resp := &LoadCodeAssistResponse{} if got := resp.GetTier(); got != "" { @@ -228,8 +254,20 @@ func TestGetTier_两者都为nil(t *testing.T) { // NewClient // --------------------------------------------------------------------------- +func mustNewClient(t *testing.T, proxyURL string) *Client { + t.Helper() + client, err := NewClient(proxyURL) + if err != nil { + t.Fatalf("NewClient(%q) failed: %v", proxyURL, err) + } + return client +} + func TestNewClient_无代理(t *testing.T) { - client := NewClient("") + client, err := NewClient("") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } if client == nil { t.Fatal("NewClient 返回 nil") } @@ -246,7 +284,10 @@ func TestNewClient_无代理(t *testing.T) { } func TestNewClient_有代理(t *testing.T) { - client := NewClient("http://proxy.example.com:8080") + client, err := NewClient("http://proxy.example.com:8080") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } if client == nil { t.Fatal("NewClient 返回 nil") } @@ -256,7 +297,10 @@ func TestNewClient_有代理(t *testing.T) { } func TestNewClient_空格代理(t *testing.T) { - client := NewClient(" ") + client, err := NewClient(" ") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } if client == nil { t.Fatal("NewClient 返回 nil") } @@ -267,15 +311,13 @@ func TestNewClient_空格代理(t *testing.T) { } func TestNewClient_无效代理URL(t *testing.T) { - // 无效 URL 时 url.Parse 不一定返回错误(Go 的 url.Parse 很宽容), - // 但 ://invalid 会导致解析错误 - client := NewClient("://invalid") - if client == nil { - t.Fatal("NewClient 返回 nil") + // 无效 URL 应返回 error + _, err := NewClient("://invalid") + if err == nil { + t.Fatal("无效代理 URL 应返回错误") } - // 无效 URL 解析失败时,Transport 应保持 nil - if client.httpClient.Transport != nil { - t.Error("无效代理 URL 时 Transport 应为 nil") + if !strings.Contains(err.Error(), "invalid proxy URL") { + 
t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error()) } } @@ -499,7 +541,7 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { defaultClientSecret = "" t.Cleanup(func() { defaultClientSecret = old }) - client := NewClient("") + client := mustNewClient(t, "") _, err := client.ExchangeCode(context.Background(), "code", "verifier") if err == nil { t.Fatal("缺少 client_secret 时应返回错误") @@ -602,7 +644,7 @@ func TestClient_RefreshToken_无ClientSecret(t *testing.T) { defaultClientSecret = "" t.Cleanup(func() { defaultClientSecret = old }) - client := NewClient("") + client := mustNewClient(t, "") _, err := client.RefreshToken(context.Background(), "refresh-tok") if err == nil { t.Fatal("缺少 client_secret 时应返回错误") @@ -1242,7 +1284,7 @@ func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token") if err != nil { t.Fatalf("LoadCodeAssist 失败: %v", err) @@ -1277,7 +1319,7 @@ func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.LoadCodeAssist(context.Background(), "bad-token") if err == nil { t.Fatal("服务器返回 403 时应返回错误") @@ -1300,7 +1342,7 @@ func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.LoadCodeAssist(context.Background(), "token") if err == nil { t.Fatal("无效 JSON 响应应返回错误") @@ -1333,7 +1375,7 @@ func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.LoadCodeAssist(context.Background(), "token") if err != nil { t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err) @@ 
-1361,7 +1403,7 @@ func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.LoadCodeAssist(context.Background(), "token") if err == nil { t.Fatal("所有 URL 都失败时应返回错误") @@ -1377,7 +1419,7 @@ func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -1441,7 +1483,7 @@ func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc") if err != nil { t.Fatalf("FetchAvailableModels 失败: %v", err) @@ -1496,7 +1538,7 @@ func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj") if err == nil { t.Fatal("服务器返回 403 时应返回错误") @@ -1516,7 +1558,7 @@ func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err == nil { t.Fatal("无效 JSON 响应应返回错误") @@ -1546,7 +1588,7 @@ func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err) @@ -1574,7 +1616,7 @@ func 
TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err == nil { t.Fatal("所有 URL 都失败时应返回错误") @@ -1590,7 +1632,7 @@ func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -1610,7 +1652,7 @@ func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 失败: %v", err) @@ -1646,7 +1688,7 @@ func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.LoadCodeAssist(context.Background(), "token") if err != nil { t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err) @@ -1672,7 +1714,7 @@ func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := NewClient("") + client := mustNewClient(t, "") resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err) diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 0ff24a1f..1a0ca5bb 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -189,6 +189,5 @@ var DefaultStopSequences = []string{ "<|user|>", "<|endoftext|>", "<|end_of_turn|>", 
- "[DONE]", "\n\nHuman:", } diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index 18310655..5bda31ac 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -49,12 +49,11 @@ const ( antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" ) -// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6 -var defaultUserAgentVersion = "1.19.6" +// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4 +var defaultUserAgentVersion = "1.20.4" // defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 -// 默认值使用占位符,生产环境请通过环境变量注入真实值。 -var defaultClientSecret = "GOCSPX-your-client-secret" +var defaultClientSecret = "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" func init() { // 从环境变量读取版本号,未设置则使用默认值 diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go index 2a2a52e9..f4630b09 100644 --- a/backend/internal/pkg/antigravity/oauth_test.go +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -684,13 +684,13 @@ func TestConstants_值正确(t *testing.T) { if err != nil { t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err) } - if secret != "GOCSPX-your-client-secret" { + if secret != "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" { t.Errorf("默认 client_secret 不匹配: got %s", secret) } if RedirectURI != "http://localhost:8085/callback" { t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) } - if GetUserAgent() != "antigravity/1.19.6 windows/amd64" { + if GetUserAgent() != "antigravity/1.20.4 windows/amd64" { t.Errorf("UserAgent 不匹配: got %s", GetUserAgent()) } if SessionTTL != 30*time.Minute { diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index 677435ad..deed5f92 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -119,23 +119,33 @@ func 
(p *StreamingProcessor) ProcessLine(line string) []byte { return result.Bytes() } -// Finish 结束处理,返回最终事件和用量 +// Finish 结束处理,返回最终事件和用量。 +// 若整个流未收到任何可解析的上游数据(messageStartSent == false), +// 则不补发任何结束事件,防止客户端收到没有 message_start 的残缺流。 func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) { - var result bytes.Buffer - - if !p.messageStopSent { - _, _ = result.Write(p.emitFinish("")) - } - usage := &ClaudeUsage{ InputTokens: p.inputTokens, OutputTokens: p.outputTokens, CacheReadInputTokens: p.cacheReadTokens, } + if !p.messageStartSent { + return nil, usage + } + + var result bytes.Buffer + if !p.messageStopSent { + _, _ = result.Write(p.emitFinish("")) + } + return result.Bytes(), usage } +// MessageStartSent 报告流中是否已发出过 message_start 事件(即是否收到过有效的上游数据) +func (p *StreamingProcessor) MessageStartSent() bool { + return p.messageStartSent +} + // emitMessageStart 发送 message_start 事件 func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte { if p.messageStartSent { diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go new file mode 100644 index 00000000..2db65572 --- /dev/null +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -0,0 +1,1010 @@ +package apicompat + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// AnthropicToResponses tests +// --------------------------------------------------------------------------- + +func TestAnthropicToResponses_BasicText(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Stream: true, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Hello"`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + assert.Equal(t, "gpt-5.2", resp.Model) + assert.True(t, resp.Stream) + 
assert.Equal(t, 1024, *resp.MaxOutputTokens) + assert.False(t, *resp.Store) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + assert.Equal(t, "user", items[0].Role) +} + +func TestAnthropicToResponses_SystemPrompt(t *testing.T) { + t.Run("string", func(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 100, + System: json.RawMessage(`"You are helpful."`), + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + assert.Equal(t, "system", items[0].Role) + }) + + t.Run("array", func(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 100, + System: json.RawMessage(`[{"type":"text","text":"Part 1"},{"type":"text","text":"Part 2"}]`), + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + assert.Equal(t, "system", items[0].Role) + // System text should be joined with double newline. 
+ var text string + require.NoError(t, json.Unmarshal(items[0].Content, &text)) + assert.Equal(t, "Part 1\n\nPart 2", text) + }) +} + +func TestAnthropicToResponses_ToolUse(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"What is the weather?"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"Let me check."},{"type":"tool_use","id":"call_1","name":"get_weather","input":{"city":"NYC"}}]`)}, + {Role: "user", Content: json.RawMessage(`[{"type":"tool_result","tool_use_id":"call_1","content":"Sunny, 72°F"}]`)}, + }, + Tools: []AnthropicTool{ + {Name: "get_weather", Description: "Get weather", InputSchema: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + // Check tools + require.Len(t, resp.Tools, 1) + assert.Equal(t, "function", resp.Tools[0].Type) + assert.Equal(t, "get_weather", resp.Tools[0].Name) + + // Check input items + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + assistant + function_call + function_call_output = 4 + require.Len(t, items, 4) + + assert.Equal(t, "user", items[0].Role) + assert.Equal(t, "assistant", items[1].Role) + assert.Equal(t, "function_call", items[2].Type) + assert.Equal(t, "fc_call_1", items[2].CallID) + assert.Empty(t, items[2].ID) + assert.Equal(t, "function_call_output", items[3].Type) + assert.Equal(t, "fc_call_1", items[3].CallID) + assert.Equal(t, "Sunny, 72°F", items[3].Output) +} + +func TestAnthropicToResponses_ThinkingIgnored(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Hello"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"deep thought"},{"type":"text","text":"Hi!"}]`)}, + {Role: 
"user", Content: json.RawMessage(`"More"`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + assistant(text only, thinking ignored) + user = 3 + require.Len(t, items, 3) + assert.Equal(t, "assistant", items[1].Role) + // Assistant content should only have text, not thinking. + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[1].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "output_text", parts[0].Type) + assert.Equal(t, "Hi!", parts[0].Text) +} + +func TestAnthropicToResponses_MaxTokensFloor(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 10, // below minMaxOutputTokens (128) + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + assert.Equal(t, 128, *resp.MaxOutputTokens) +} + +// --------------------------------------------------------------------------- +// ResponsesToAnthropic (non-streaming) tests +// --------------------------------------------------------------------------- + +func TestResponsesToAnthropic_TextOnly(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_123", + Model: "gpt-5.2", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "Hello there!"}, + }, + }, + }, + Usage: &ResponsesUsage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + assert.Equal(t, "resp_123", anth.ID) + assert.Equal(t, "claude-opus-4-6", anth.Model) + assert.Equal(t, "end_turn", anth.StopReason) + require.Len(t, anth.Content, 1) + assert.Equal(t, "text", anth.Content[0].Type) + assert.Equal(t, "Hello there!", anth.Content[0].Text) + assert.Equal(t, 10, anth.Usage.InputTokens) + assert.Equal(t, 5, 
anth.Usage.OutputTokens) +} + +func TestResponsesToAnthropic_ToolUse(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_456", + Model: "gpt-5.2", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "Let me check."}, + }, + }, + { + Type: "function_call", + CallID: "call_1", + Name: "get_weather", + Arguments: `{"city":"NYC"}`, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + assert.Equal(t, "tool_use", anth.StopReason) + require.Len(t, anth.Content, 2) + assert.Equal(t, "text", anth.Content[0].Type) + assert.Equal(t, "tool_use", anth.Content[1].Type) + assert.Equal(t, "call_1", anth.Content[1].ID) + assert.Equal(t, "get_weather", anth.Content[1].Name) +} + +func TestResponsesToAnthropic_Reasoning(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_789", + Model: "gpt-5.2", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "reasoning", + Summary: []ResponsesSummary{ + {Type: "summary_text", Text: "Thinking about the answer..."}, + }, + }, + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "42"}, + }, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + require.Len(t, anth.Content, 2) + assert.Equal(t, "thinking", anth.Content[0].Type) + assert.Equal(t, "Thinking about the answer...", anth.Content[0].Thinking) + assert.Equal(t, "text", anth.Content[1].Type) + assert.Equal(t, "42", anth.Content[1].Text) +} + +func TestResponsesToAnthropic_Incomplete(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_inc", + Model: "gpt-5.2", + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{ + Reason: "max_output_tokens", + }, + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{{Type: "output_text", Text: "Partial..."}}, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + assert.Equal(t, "max_tokens", 
anth.StopReason) +} + +func TestResponsesToAnthropic_EmptyOutput(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_empty", + Model: "gpt-5.2", + Status: "completed", + Output: []ResponsesOutput{}, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + require.Len(t, anth.Content, 1) + assert.Equal(t, "text", anth.Content[0].Type) + assert.Equal(t, "", anth.Content[0].Text) +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesEventToAnthropicEvents tests +// --------------------------------------------------------------------------- + +func TestStreamingTextOnly(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + // 1. response.created + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ + ID: "resp_1", + Model: "gpt-5.2", + }, + }, state) + require.Len(t, events, 1) + assert.Equal(t, "message_start", events[0].Type) + + // 2. output_item.added (message) + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 0, + Item: &ResponsesOutput{Type: "message"}, + }, state) + assert.Len(t, events, 0) // message item doesn't emit events + + // 3. text delta + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Hello", + }, state) + require.Len(t, events, 2) // content_block_start + content_block_delta + assert.Equal(t, "content_block_start", events[0].Type) + assert.Equal(t, "text", events[0].ContentBlock.Type) + assert.Equal(t, "content_block_delta", events[1].Type) + assert.Equal(t, "text_delta", events[1].Delta.Type) + assert.Equal(t, "Hello", events[1].Delta.Text) + + // 4. 
more text + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: " world", + }, state) + require.Len(t, events, 1) // only delta, no new block start + assert.Equal(t, "content_block_delta", events[0].Type) + + // 5. text done + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.done", + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_stop", events[0].Type) + + // 6. completed + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 10, OutputTokens: 5}, + }, + }, state) + require.Len(t, events, 2) // message_delta + message_stop + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "end_turn", events[0].Delta.StopReason) + assert.Equal(t, 10, events[0].Usage.InputTokens) + assert.Equal(t, 5, events[0].Usage.OutputTokens) + assert.Equal(t, "message_stop", events[1].Type) +} + +func TestStreamingToolCall(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + // 1. response.created + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_2", Model: "gpt-5.2"}, + }, state) + + // 2. function_call added + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 0, + Item: &ResponsesOutput{Type: "function_call", CallID: "call_1", Name: "get_weather"}, + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_start", events[0].Type) + assert.Equal(t, "tool_use", events[0].ContentBlock.Type) + assert.Equal(t, "call_1", events[0].ContentBlock.ID) + + // 3. 
arguments delta + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 0, + Delta: `{"city":`, + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_delta", events[0].Type) + assert.Equal(t, "input_json_delta", events[0].Delta.Type) + assert.Equal(t, `{"city":`, events[0].Delta.PartialJSON) + + // 4. arguments done + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.done", + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_stop", events[0].Type) + + // 5. completed with tool_calls + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 20, OutputTokens: 10}, + }, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "tool_use", events[0].Delta.StopReason) +} + +func TestStreamingReasoning(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_3", Model: "gpt-5.2"}, + }, state) + + // reasoning item added + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 0, + Item: &ResponsesOutput{Type: "reasoning"}, + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_start", events[0].Type) + assert.Equal(t, "thinking", events[0].ContentBlock.Type) + + // reasoning text delta + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.delta", + OutputIndex: 0, + Delta: "Let me think...", + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_delta", events[0].Type) + assert.Equal(t, "thinking_delta", events[0].Delta.Type) + assert.Equal(t, "Let me think...", events[0].Delta.Thinking) + + // 
reasoning done + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.done", + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_stop", events[0].Type) +} + +func TestStreamingIncomplete(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_4", Model: "gpt-5.2"}, + }, state) + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Partial output...", + }, state) + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.incomplete", + Response: &ResponsesResponse{ + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Usage: &ResponsesUsage{InputTokens: 100, OutputTokens: 4096}, + }, + }, state) + + // Should close the text block + message_delta + message_stop + require.Len(t, events, 3) + assert.Equal(t, "content_block_stop", events[0].Type) + assert.Equal(t, "message_delta", events[1].Type) + assert.Equal(t, "max_tokens", events[1].Delta.StopReason) + assert.Equal(t, "message_stop", events[2].Type) +} + +func TestFinalizeStream_NeverStarted(t *testing.T) { + state := NewResponsesEventToAnthropicState() + events := FinalizeResponsesAnthropicStream(state) + assert.Nil(t, events) +} + +func TestFinalizeStream_AlreadyCompleted(t *testing.T) { + state := NewResponsesEventToAnthropicState() + state.MessageStartSent = true + state.MessageStopSent = true + events := FinalizeResponsesAnthropicStream(state) + assert.Nil(t, events) +} + +func TestFinalizeStream_AbnormalTermination(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + // Simulate a stream that started but never completed + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_5", Model: 
"gpt-5.2"}, + }, state) + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Interrupted...", + }, state) + + // Stream ends without response.completed + events := FinalizeResponsesAnthropicStream(state) + require.Len(t, events, 3) // content_block_stop + message_delta + message_stop + assert.Equal(t, "content_block_stop", events[0].Type) + assert.Equal(t, "message_delta", events[1].Type) + assert.Equal(t, "end_turn", events[1].Delta.StopReason) + assert.Equal(t, "message_stop", events[2].Type) +} + +func TestStreamingEmptyResponse(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_6", Model: "gpt-5.2"}, + }, state) + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 5, OutputTokens: 0}, + }, + }, state) + + require.Len(t, events, 2) // message_delta + message_stop + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "end_turn", events[0].Delta.StopReason) +} + +func TestResponsesAnthropicEventToSSE(t *testing.T) { + evt := AnthropicStreamEvent{ + Type: "message_start", + Message: &AnthropicResponse{ + ID: "resp_1", + Type: "message", + Role: "assistant", + }, + } + sse, err := ResponsesAnthropicEventToSSE(evt) + require.NoError(t, err) + assert.Contains(t, sse, "event: message_start\n") + assert.Contains(t, sse, "data: ") + assert.Contains(t, sse, `"resp_1"`) +} + +// --------------------------------------------------------------------------- +// response.failed tests +// --------------------------------------------------------------------------- + +func TestStreamingFailed(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + // 1. 
response.created + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_fail_1", Model: "gpt-5.2"}, + }, state) + + // 2. Some text output before failure + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Partial output before failure", + }, state) + + // 3. response.failed + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.failed", + Response: &ResponsesResponse{ + Status: "failed", + Error: &ResponsesError{Code: "server_error", Message: "Internal error"}, + Usage: &ResponsesUsage{InputTokens: 50, OutputTokens: 10}, + }, + }, state) + + // Should close text block + message_delta + message_stop + require.Len(t, events, 3) + assert.Equal(t, "content_block_stop", events[0].Type) + assert.Equal(t, "message_delta", events[1].Type) + assert.Equal(t, "end_turn", events[1].Delta.StopReason) + assert.Equal(t, 50, events[1].Usage.InputTokens) + assert.Equal(t, 10, events[1].Usage.OutputTokens) + assert.Equal(t, "message_stop", events[2].Type) +} + +func TestStreamingFailedNoOutput(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + // 1. response.created + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_fail_2", Model: "gpt-5.2"}, + }, state) + + // 2. 
response.failed with no prior output + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.failed", + Response: &ResponsesResponse{ + Status: "failed", + Error: &ResponsesError{Code: "rate_limit_error", Message: "Too many requests"}, + Usage: &ResponsesUsage{InputTokens: 20, OutputTokens: 0}, + }, + }, state) + + // Should emit message_delta + message_stop (no block to close) + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "end_turn", events[0].Delta.StopReason) + assert.Equal(t, "message_stop", events[1].Type) +} + +func TestResponsesToAnthropic_Failed(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_fail_3", + Model: "gpt-5.2", + Status: "failed", + Error: &ResponsesError{Code: "server_error", Message: "Something went wrong"}, + Output: []ResponsesOutput{}, + Usage: &ResponsesUsage{InputTokens: 30, OutputTokens: 0}, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + // Failed status defaults to "end_turn" stop reason + assert.Equal(t, "end_turn", anth.StopReason) + // Should have at least an empty text block + require.Len(t, anth.Content, 1) + assert.Equal(t, "text", anth.Content[0].Type) +} + +// --------------------------------------------------------------------------- +// thinking → reasoning conversion tests +// --------------------------------------------------------------------------- + +func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + // thinking.type is ignored for effort; default xhigh applies. 
+ assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) + assert.Contains(t, resp.Include, "reasoning.encrypted_content") + assert.NotContains(t, resp.Include, "reasoning.summary") +} + +func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + Thinking: &AnthropicThinking{Type: "adaptive", BudgetTokens: 5000}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + // thinking.type is ignored for effort; default xhigh applies. + assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) + assert.NotContains(t, resp.Include, "reasoning.summary") +} + +func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + Thinking: &AnthropicThinking{Type: "disabled"}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + // Default effort applies (high → xhigh) even when thinking is disabled. + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "xhigh", resp.Reasoning.Effort) +} + +func TestAnthropicToResponses_NoThinking(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + // Default effort applies (high → xhigh) when no thinking/output_config is set. 
+ require.NotNil(t, resp.Reasoning) + assert.Equal(t, "xhigh", resp.Reasoning.Effort) +} + +// --------------------------------------------------------------------------- +// output_config.effort override tests +// --------------------------------------------------------------------------- + +func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) { + // Default is xhigh, but output_config.effort="low" overrides. low→low after mapping. + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000}, + OutputConfig: &AnthropicOutputConfig{Effort: "low"}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "low", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) { + // No thinking field, but output_config.effort="medium" → creates reasoning. + // medium→high after mapping. + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + OutputConfig: &AnthropicOutputConfig{Effort: "medium"}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) { + // output_config.effort="high" → mapped to "xhigh". 
+ req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + OutputConfig: &AnthropicOutputConfig{Effort: "high"}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestAnthropicToResponses_NoOutputConfig(t *testing.T) { + // No output_config → default xhigh regardless of thinking.type. + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "xhigh", resp.Reasoning.Effort) +} + +func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) { + // output_config present but effort empty (e.g. only format set) → default xhigh. 
+ req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + OutputConfig: &AnthropicOutputConfig{}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "xhigh", resp.Reasoning.Effort) +} + +// --------------------------------------------------------------------------- +// tool_choice conversion tests +// --------------------------------------------------------------------------- + +func TestAnthropicToResponses_ToolChoiceAuto(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + ToolChoice: json.RawMessage(`{"type":"auto"}`), + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var tc string + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "auto", tc) +} + +func TestAnthropicToResponses_ToolChoiceAny(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + ToolChoice: json.RawMessage(`{"type":"any"}`), + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var tc string + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "required", tc) +} + +func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + ToolChoice: json.RawMessage(`{"type":"tool","name":"get_weather"}`), + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var tc map[string]any + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "function", tc["type"]) + fn, ok := tc["function"].(map[string]any) + 
require.True(t, ok) + assert.Equal(t, "get_weather", fn["name"]) +} + +// --------------------------------------------------------------------------- +// Image content block conversion tests +// --------------------------------------------------------------------------- + +func TestAnthropicToResponses_UserImageBlock(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`[ + {"type":"text","text":"What is in this image?"}, + {"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + assert.Equal(t, "user", items[0].Role) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 2) + assert.Equal(t, "input_text", parts[0].Type) + assert.Equal(t, "What is in this image?", parts[0].Text) + assert.Equal(t, "input_image", parts[1].Type) + assert.Equal(t, "data:image/png;base64,iVBOR", parts[1].ImageURL) +} + +func TestAnthropicToResponses_ImageOnlyUserMessage(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`[ + {"type":"image","source":{"type":"base64","media_type":"image/jpeg","data":"/9j/4AAQ"}} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_image", parts[0].Type) + assert.Equal(t, "data:image/jpeg;base64,/9j/4AAQ", parts[0].ImageURL) +} + +func 
TestAnthropicToResponses_ToolResultWithImage(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Read the screenshot"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_1","name":"Read","input":{"file_path":"/tmp/screen.png"}}]`)}, + {Role: "user", Content: json.RawMessage(`[ + {"type":"tool_result","tool_use_id":"toolu_1","content":[ + {"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}} + ]} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + function_call + function_call_output + user(image) = 4 + require.Len(t, items, 4) + + // function_call_output should have text-only output (no image). + assert.Equal(t, "function_call_output", items[2].Type) + assert.Equal(t, "fc_toolu_1", items[2].CallID) + assert.Equal(t, "(empty)", items[2].Output) + + // Image should be in a separate user message. 
+ assert.Equal(t, "user", items[3].Role) + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[3].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_image", parts[0].Type) + assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL) +} + +func TestAnthropicToResponses_ToolResultMixed(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Describe the file"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_2","name":"Read","input":{"file_path":"/tmp/photo.png"}}]`)}, + {Role: "user", Content: json.RawMessage(`[ + {"type":"tool_result","tool_use_id":"toolu_2","content":[ + {"type":"text","text":"File metadata: 800x600 PNG"}, + {"type":"image","source":{"type":"base64","media_type":"image/png","data":"AAAA"}} + ]} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + function_call + function_call_output + user(image) = 4 + require.Len(t, items, 4) + + // function_call_output should have text-only output. + assert.Equal(t, "function_call_output", items[2].Type) + assert.Equal(t, "File metadata: 800x600 PNG", items[2].Output) + + // Image should be in a separate user message. 
+ assert.Equal(t, "user", items[3].Role) + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[3].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_image", parts[0].Type) + assert.Equal(t, "data:image/png;base64,AAAA", parts[0].ImageURL) +} + +func TestAnthropicToResponses_TextOnlyToolResultBackwardCompat(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Check weather"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"call_1","name":"get_weather","input":{"city":"NYC"}}]`)}, + {Role: "user", Content: json.RawMessage(`[ + {"type":"tool_result","tool_use_id":"call_1","content":[ + {"type":"text","text":"Sunny, 72°F"} + ]} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + function_call + function_call_output = 3 + require.Len(t, items, 3) + + // Text-only tool_result should produce a plain string. + assert.Equal(t, "Sunny, 72°F", items[2].Output) +} + +func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`[ + {"type":"image","source":{"type":"base64","media_type":"","data":"iVBOR"}} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_image", parts[0].Type) + // Should default to image/png when media_type is empty. 
+ assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL) +} diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses.go b/backend/internal/pkg/apicompat/anthropic_to_responses.go new file mode 100644 index 00000000..0a747869 --- /dev/null +++ b/backend/internal/pkg/apicompat/anthropic_to_responses.go @@ -0,0 +1,416 @@ +package apicompat + +import ( + "encoding/json" + "fmt" + "strings" +) + +// AnthropicToResponses converts an Anthropic Messages request directly into +// a Responses API request. This preserves fields that would be lost in a +// Chat Completions intermediary round-trip (e.g. thinking, cache_control, +// structured system prompts). +func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) { + input, err := convertAnthropicToResponsesInput(req.System, req.Messages) + if err != nil { + return nil, err + } + + inputJSON, err := json.Marshal(input) + if err != nil { + return nil, err + } + + out := &ResponsesRequest{ + Model: req.Model, + Input: inputJSON, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: req.Stream, + Include: []string{"reasoning.encrypted_content"}, + } + + storeFalse := false + out.Store = &storeFalse + + if req.MaxTokens > 0 { + v := req.MaxTokens + if v < minMaxOutputTokens { + v = minMaxOutputTokens + } + out.MaxOutputTokens = &v + } + + if len(req.Tools) > 0 { + out.Tools = convertAnthropicToolsToResponses(req.Tools) + } + + // Determine reasoning effort: only output_config.effort controls the + // level; thinking.type is ignored. Default is xhigh when unset. + // Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh. 
+ effort := "high" // default → maps to xhigh + if req.OutputConfig != nil && req.OutputConfig.Effort != "" { + effort = req.OutputConfig.Effort + } + out.Reasoning = &ResponsesReasoning{ + Effort: mapAnthropicEffortToResponses(effort), + Summary: "auto", + } + + // Convert tool_choice + if len(req.ToolChoice) > 0 { + tc, err := convertAnthropicToolChoiceToResponses(req.ToolChoice) + if err != nil { + return nil, fmt.Errorf("convert tool_choice: %w", err) + } + out.ToolChoice = tc + } + + return out, nil +} + +// convertAnthropicToolChoiceToResponses maps Anthropic tool_choice to Responses format. +// +// {"type":"auto"} → "auto" +// {"type":"any"} → "required" +// {"type":"none"} → "none" +// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}} +func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) { + var tc struct { + Type string `json:"type"` + Name string `json:"name"` + } + if err := json.Unmarshal(raw, &tc); err != nil { + return nil, err + } + + switch tc.Type { + case "auto": + return json.Marshal("auto") + case "any": + return json.Marshal("required") + case "none": + return json.Marshal("none") + case "tool": + return json.Marshal(map[string]any{ + "type": "function", + "function": map[string]string{"name": tc.Name}, + }) + default: + // Pass through unknown types as-is + return raw, nil + } +} + +// convertAnthropicToResponsesInput builds the Responses API input items array +// from the Anthropic system field and message list. +func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMessage) ([]ResponsesInputItem, error) { + var out []ResponsesInputItem + + // System prompt → system role input item. 
+ if len(system) > 0 { + sysText, err := parseAnthropicSystemPrompt(system) + if err != nil { + return nil, err + } + if sysText != "" { + content, _ := json.Marshal(sysText) + out = append(out, ResponsesInputItem{ + Role: "system", + Content: content, + }) + } + } + + for _, m := range msgs { + items, err := anthropicMsgToResponsesItems(m) + if err != nil { + return nil, err + } + out = append(out, items...) + } + return out, nil +} + +// parseAnthropicSystemPrompt handles the Anthropic system field which can be +// a plain string or an array of text blocks. +func parseAnthropicSystemPrompt(raw json.RawMessage) (string, error) { + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s, nil + } + var blocks []AnthropicContentBlock + if err := json.Unmarshal(raw, &blocks); err != nil { + return "", err + } + var parts []string + for _, b := range blocks { + if b.Type == "text" && b.Text != "" { + parts = append(parts, b.Text) + } + } + return strings.Join(parts, "\n\n"), nil +} + +// anthropicMsgToResponsesItems converts a single Anthropic message into one +// or more Responses API input items. +func anthropicMsgToResponsesItems(m AnthropicMessage) ([]ResponsesInputItem, error) { + switch m.Role { + case "user": + return anthropicUserToResponses(m.Content) + case "assistant": + return anthropicAssistantToResponses(m.Content) + default: + return anthropicUserToResponses(m.Content) + } +} + +// anthropicUserToResponses handles an Anthropic user message. Content can be a +// plain string or an array of blocks. tool_result blocks are extracted into +// function_call_output items. Image blocks are converted to input_image parts. +func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) { + // Try plain string. 
+ var s string + if err := json.Unmarshal(raw, &s); err == nil { + content, _ := json.Marshal(s) + return []ResponsesInputItem{{Role: "user", Content: content}}, nil + } + + var blocks []AnthropicContentBlock + if err := json.Unmarshal(raw, &blocks); err != nil { + return nil, err + } + + var out []ResponsesInputItem + var toolResultImageParts []ResponsesContentPart + + // Extract tool_result blocks → function_call_output items. + // Images inside tool_results are extracted separately because the + // Responses API function_call_output.output only accepts strings. + for _, b := range blocks { + if b.Type != "tool_result" { + continue + } + outputText, imageParts := convertToolResultOutput(b) + out = append(out, ResponsesInputItem{ + Type: "function_call_output", + CallID: toResponsesCallID(b.ToolUseID), + Output: outputText, + }) + toolResultImageParts = append(toolResultImageParts, imageParts...) + } + + // Remaining text + image blocks → user message with content parts. + // Also include images extracted from tool_results so the model can see them. + var parts []ResponsesContentPart + for _, b := range blocks { + switch b.Type { + case "text": + if b.Text != "" { + parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text}) + } + case "image": + if uri := anthropicImageToDataURI(b.Source); uri != "" { + parts = append(parts, ResponsesContentPart{Type: "input_image", ImageURL: uri}) + } + } + } + parts = append(parts, toolResultImageParts...) + + if len(parts) > 0 { + content, err := json.Marshal(parts) + if err != nil { + return nil, err + } + out = append(out, ResponsesInputItem{Role: "user", Content: content}) + } + + return out, nil +} + +// anthropicAssistantToResponses handles an Anthropic assistant message. +// Text content → assistant message with output_text parts. +// tool_use blocks → function_call items. +// thinking blocks → ignored (OpenAI doesn't accept them as input). 
+func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) { + // Try plain string. + var s string + if err := json.Unmarshal(raw, &s); err == nil { + parts := []ResponsesContentPart{{Type: "output_text", Text: s}} + partsJSON, err := json.Marshal(parts) + if err != nil { + return nil, err + } + return []ResponsesInputItem{{Role: "assistant", Content: partsJSON}}, nil + } + + var blocks []AnthropicContentBlock + if err := json.Unmarshal(raw, &blocks); err != nil { + return nil, err + } + + var items []ResponsesInputItem + + // Text content → assistant message with output_text content parts. + text := extractAnthropicTextFromBlocks(blocks) + if text != "" { + parts := []ResponsesContentPart{{Type: "output_text", Text: text}} + partsJSON, err := json.Marshal(parts) + if err != nil { + return nil, err + } + items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON}) + } + + // tool_use → function_call items. + for _, b := range blocks { + if b.Type != "tool_use" { + continue + } + args := "{}" + if len(b.Input) > 0 { + args = string(b.Input) + } + fcID := toResponsesCallID(b.ID) + items = append(items, ResponsesInputItem{ + Type: "function_call", + CallID: fcID, + Name: b.Name, + Arguments: args, + }) + } + + return items, nil +} + +// toResponsesCallID converts an Anthropic tool ID (toolu_xxx / call_xxx) to a +// Responses API function_call ID that starts with "fc_". +func toResponsesCallID(id string) string { + if strings.HasPrefix(id, "fc_") { + return id + } + return "fc_" + id +} + +// fromResponsesCallID reverses toResponsesCallID, stripping the "fc_" prefix +// that was added during request conversion. +func fromResponsesCallID(id string) string { + if after, ok := strings.CutPrefix(id, "fc_"); ok { + // Only strip if the remainder doesn't look like it was already "fc_" prefixed. + // E.g. 
"fc_toolu_xxx" → "toolu_xxx", "fc_call_xxx" → "call_xxx" + if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") { + return after + } + } + return id +} + +// anthropicImageToDataURI converts an AnthropicImageSource to a data URI string. +// Returns "" if the source is nil or has no data. +func anthropicImageToDataURI(src *AnthropicImageSource) string { + if src == nil || src.Data == "" { + return "" + } + mediaType := src.MediaType + if mediaType == "" { + mediaType = "image/png" + } + return "data:" + mediaType + ";base64," + src.Data +} + +// convertToolResultOutput extracts text and image content from a tool_result +// block. Returns the text as a string for the function_call_output Output +// field, plus any image parts that must be sent in a separate user message +// (the Responses API output field only accepts strings). +func convertToolResultOutput(b AnthropicContentBlock) (string, []ResponsesContentPart) { + if len(b.Content) == 0 { + return "(empty)", nil + } + + // Try plain string content. + var s string + if err := json.Unmarshal(b.Content, &s); err == nil { + if s == "" { + s = "(empty)" + } + return s, nil + } + + // Array of content blocks — may contain text and/or images. + var inner []AnthropicContentBlock + if err := json.Unmarshal(b.Content, &inner); err != nil { + return "(empty)", nil + } + + // Separate text (for function_call_output) from images (for user message). 
+ var textParts []string + var imageParts []ResponsesContentPart + for _, ib := range inner { + switch ib.Type { + case "text": + if ib.Text != "" { + textParts = append(textParts, ib.Text) + } + case "image": + if uri := anthropicImageToDataURI(ib.Source); uri != "" { + imageParts = append(imageParts, ResponsesContentPart{Type: "input_image", ImageURL: uri}) + } + } + } + + text := strings.Join(textParts, "\n\n") + if text == "" { + text = "(empty)" + } + return text, imageParts +} + +// extractAnthropicTextFromBlocks joins all text blocks, ignoring thinking/ +// tool_use/tool_result blocks. +func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string { + var parts []string + for _, b := range blocks { + if b.Type == "text" && b.Text != "" { + parts = append(parts, b.Text) + } + } + return strings.Join(parts, "\n\n") +} + +// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to +// OpenAI Responses API effort levels. +// +// low → low +// medium → high +// high → xhigh +func mapAnthropicEffortToResponses(effort string) string { + switch effort { + case "medium": + return "high" + case "high": + return "xhigh" + default: + return effort // "low" and any unknown values pass through unchanged + } +} + +// convertAnthropicToolsToResponses maps Anthropic tool definitions to +// Responses API tools. Server-side tools like web_search are mapped to their +// OpenAI equivalents; regular tools become function tools. 
+func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool { + var out []ResponsesTool + for _, t := range tools { + // Anthropic server tools like "web_search_20250305" → OpenAI {"type":"web_search"} + if strings.HasPrefix(t.Type, "web_search") { + out = append(out, ResponsesTool{Type: "web_search"}) + continue + } + out = append(out, ResponsesTool{ + Type: "function", + Name: t.Name, + Description: t.Description, + Parameters: t.InputSchema, + }) + } + return out +} diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go new file mode 100644 index 00000000..8b819033 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -0,0 +1,810 @@ +package apicompat + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// ChatCompletionsToResponses tests +// --------------------------------------------------------------------------- + +func TestChatCompletionsToResponses_BasicText(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hello"`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + assert.Equal(t, "gpt-4o", resp.Model) + assert.True(t, resp.Stream) // always forced true + assert.False(t, *resp.Store) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + assert.Equal(t, "user", items[0].Role) +} + +func TestChatCompletionsToResponses_SystemMessage(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "system", Content: json.RawMessage(`"You are helpful."`)}, + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + }, + } + + resp, err := 
ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + assert.Equal(t, "system", items[0].Role) + assert.Equal(t, "user", items[1].Role) +} + +func TestChatCompletionsToResponses_ToolCalls(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Call the function"`)}, + { + Role: "assistant", + ToolCalls: []ChatToolCall{ + { + ID: "call_1", + Type: "function", + Function: ChatFunctionCall{ + Name: "ping", + Arguments: `{"host":"example.com"}`, + }, + }, + }, + }, + { + Role: "tool", + ToolCallID: "call_1", + Content: json.RawMessage(`"pong"`), + }, + }, + Tools: []ChatTool{ + { + Type: "function", + Function: &ChatFunction{ + Name: "ping", + Description: "Ping a host", + Parameters: json.RawMessage(`{"type":"object"}`), + }, + }, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + function_call + function_call_output = 3 + // (assistant message with empty content + tool_calls → only function_call items emitted) + require.Len(t, items, 3) + + // Check function_call item + assert.Equal(t, "function_call", items[1].Type) + assert.Equal(t, "call_1", items[1].CallID) + assert.Empty(t, items[1].ID) + assert.Equal(t, "ping", items[1].Name) + + // Check function_call_output item + assert.Equal(t, "function_call_output", items[2].Type) + assert.Equal(t, "call_1", items[2].CallID) + assert.Equal(t, "pong", items[2].Output) + + // Check tools + require.Len(t, resp.Tools, 1) + assert.Equal(t, "function", resp.Tools[0].Type) + assert.Equal(t, "ping", resp.Tools[0].Name) +} + +func TestChatCompletionsToResponses_MaxTokens(t *testing.T) { + t.Run("max_tokens", func(t *testing.T) { + maxTokens := 100 + req := &ChatCompletionsRequest{ + Model: 
"gpt-4o", + MaxTokens: &maxTokens, + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.MaxOutputTokens) + // Below minMaxOutputTokens (128), should be clamped + assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens) + }) + + t.Run("max_completion_tokens_preferred", func(t *testing.T) { + maxTokens := 100 + maxCompletion := 500 + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + MaxTokens: &maxTokens, + MaxCompletionTokens: &maxCompletion, + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.MaxOutputTokens) + assert.Equal(t, 500, *resp.MaxOutputTokens) + }) +} + +func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + ReasoningEffort: "high", + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestChatCompletionsToResponses_ImageURL(t *testing.T) { + content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]` + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(content)}, + }, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 2) + assert.Equal(t, "input_text", parts[0].Type) + assert.Equal(t, "Describe 
this", parts[0].Text) + assert.Equal(t, "input_image", parts[1].Type) + assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL) +} + +func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + }, + Functions: []ChatFunction{ + { + Name: "get_weather", + Description: "Get weather", + Parameters: json.RawMessage(`{"type":"object"}`), + }, + }, + FunctionCall: json.RawMessage(`{"name":"get_weather"}`), + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.Len(t, resp.Tools, 1) + assert.Equal(t, "function", resp.Tools[0].Type) + assert.Equal(t, "get_weather", resp.Tools[0].Name) + + // tool_choice should be converted + require.NotNil(t, resp.ToolChoice) + var tc map[string]any + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "function", tc["type"]) +} + +func TestChatCompletionsToResponses_ServiceTier(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + ServiceTier: "flex", + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + assert.Equal(t, "flex", resp.ServiceTier) +} + +func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Do something"`)}, + { + Role: "assistant", + Content: json.RawMessage(`"Let me call a function."`), + ToolCalls: []ChatToolCall{ + { + ID: "call_abc", + Type: "function", + Function: ChatFunctionCall{ + Name: "do_thing", + Arguments: `{}`, + }, + }, + }, + }, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + assistant message (with 
text) + function_call + require.Len(t, items, 3) + assert.Equal(t, "user", items[0].Role) + assert.Equal(t, "assistant", items[1].Role) + assert.Equal(t, "function_call", items[2].Type) + assert.Empty(t, items[2].ID) +} + +func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + assert.Equal(t, "assistant", items[1].Role) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[1].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "output_text", parts[0].Type) + assert.Equal(t, "AB", parts[0].Text) +} + +func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[1].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "output_text", parts[0].Type) + assert.Contains(t, parts[0].Text, "internal plan") + assert.Contains(t, parts[0].Text, "final answer") +} + +// --------------------------------------------------------------------------- +// ResponsesToChatCompletions tests +// 
--------------------------------------------------------------------------- + +func TestResponsesToChatCompletions_BasicText(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_123", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "Hello, world!"}, + }, + }, + }, + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + assert.Equal(t, "chat.completion", chat.Object) + assert.Equal(t, "gpt-4o", chat.Model) + require.Len(t, chat.Choices, 1) + assert.Equal(t, "stop", chat.Choices[0].FinishReason) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + assert.Equal(t, "Hello, world!", content) + + require.NotNil(t, chat.Usage) + assert.Equal(t, 10, chat.Usage.PromptTokens) + assert.Equal(t, 5, chat.Usage.CompletionTokens) + assert.Equal(t, 15, chat.Usage.TotalTokens) +} + +func TestResponsesToChatCompletions_ToolCalls(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_456", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "function_call", + CallID: "call_xyz", + Name: "get_weather", + Arguments: `{"city":"NYC"}`, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason) + + msg := chat.Choices[0].Message + require.Len(t, msg.ToolCalls, 1) + assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID) + assert.Equal(t, "function", msg.ToolCalls[0].Type) + assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name) + assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments) +} + +func TestResponsesToChatCompletions_Reasoning(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_789", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "reasoning", + Summary: []ResponsesSummary{ + {Type: 
"summary_text", Text: "I thought about it."}, + }, + }, + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "The answer is 42."}, + }, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + assert.Equal(t, "The answer is 42.", content) + assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent) +} + +func TestResponsesToChatCompletions_Incomplete(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_inc", + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "partial..."}, + }, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "length", chat.Choices[0].FinishReason) +} + +func TestResponsesToChatCompletions_CachedTokens(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_cache", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}}, + }, + }, + Usage: &ResponsesUsage{ + InputTokens: 100, + OutputTokens: 10, + TotalTokens: 110, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 80, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.NotNil(t, chat.Usage) + require.NotNil(t, chat.Usage.PromptTokensDetails) + assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens) +} + +func TestResponsesToChatCompletions_WebSearch(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_ws", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "web_search_call", + Action: &WebSearchAction{Type: "search", Query: "test"}, + }, + { + Type: "message", + Content: 
[]ResponsesContentPart{{Type: "output_text", Text: "search results"}}, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "stop", chat.Choices[0].FinishReason) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + assert.Equal(t, "search results", content) +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesEventToChatChunks tests +// --------------------------------------------------------------------------- + +func TestResponsesEventToChatChunks_TextDelta(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + + // response.created → role chunk + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ + ID: "resp_stream", + }, + }, state) + require.Len(t, chunks, 1) + assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role) + assert.True(t, state.SentRole) + + // response.output_text.delta → content chunk + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Hello", + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.Content) + assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content) +} + +func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SentRole = true + + // response.output_item.added (function_call) — output_index=1 (e.g. 
after a message item at 0) + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 1, + Item: &ResponsesOutput{ + Type: "function_call", + CallID: "call_1", + Name: "get_weather", + }, + }, state) + require.Len(t, chunks, 1) + require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1) + tc := chunks[0].Choices[0].Delta.ToolCalls[0] + assert.Equal(t, "call_1", tc.ID) + assert.Equal(t, "get_weather", tc.Function.Name) + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index) + + // response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 1, // matches the output_index from output_item.added above + Delta: `{"city":`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call") + assert.Equal(t, `{"city":`, tc.Function.Arguments) + + // Add a second function call at output_index=2 + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 2, + Item: &ResponsesOutput{ + Type: "function_call", + CallID: "call_2", + Name: "get_time", + }, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 1, *tc.Index, "second tool call should get index 1") + + // Argument delta for second tool call + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 2, + Delta: `{"tz":"UTC"}`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1") + + // Argument delta for first tool call (interleaved) + 
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 1, + Delta: `"Tokyo"}`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0") +} + +func TestResponsesEventToChatChunks_Completed(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 50, + OutputTokens: 20, + TotalTokens: 70, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 30, + }, + }, + }, + }, state) + // finish chunk + usage chunk + require.Len(t, chunks, 2) + + // First chunk: finish_reason + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + + // Second chunk: usage + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 50, chunks[1].Usage.PromptTokens) + assert.Equal(t, 20, chunks[1].Usage.CompletionTokens) + assert.Equal(t, 70, chunks[1].Usage.TotalTokens) + require.NotNil(t, chunks[1].Usage.PromptTokensDetails) + assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens) +} + +func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SawToolCall = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + }, + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason) +} + +func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) { + state := 
NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SentRole = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.delta", + Delta: "Thinking...", + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent) + assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent) + + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.done", + }, state) + require.Len(t, chunks, 0) +} + +func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SentRole = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.delta", + Delta: "plan", + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent) + assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent) + + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "answer", + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.Content) + assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content) +} + +func TestFinalizeResponsesChatStream(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + state.Usage = &ChatUsage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + } + + chunks := FinalizeResponsesChatStream(state) + require.Len(t, chunks, 2) + + // Finish chunk + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + + // Usage chunk + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 100, chunks[1].Usage.PromptTokens) + + // Idempotent: second call returns nil + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func 
TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) { + // If response.completed already emitted the finish chunk, FinalizeResponsesChatStream + // must be a no-op (prevents double finish_reason being sent to the client). + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + // Simulate response.completed + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + }, + }, state) + require.NotEmpty(t, chunks) // finish + usage chunks + + // Now FinalizeResponsesChatStream should return nil — already finalized. + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func TestChatChunkToSSE(t *testing.T) { + chunk := ChatCompletionsChunk{ + ID: "chatcmpl-test", + Object: "chat.completion.chunk", + Created: 1700000000, + Model: "gpt-4o", + Choices: []ChatChunkChoice{ + { + Index: 0, + Delta: ChatDelta{Role: "assistant"}, + FinishReason: nil, + }, + }, + } + + sse, err := ChatChunkToSSE(chunk) + require.NoError(t, err) + assert.Contains(t, sse, "data: ") + assert.Contains(t, sse, "chatcmpl-test") + assert.Contains(t, sse, "assistant") + assert.True(t, len(sse) > 10) +} + +// --------------------------------------------------------------------------- +// Stream round-trip test +// --------------------------------------------------------------------------- + +func TestChatCompletionsStreamRoundTrip(t *testing.T) { + // Simulate: client sends chat completions request, upstream returns Responses SSE events. + // Verify that the streaming state machine produces correct chat completions chunks. + + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + var allChunks []ChatCompletionsChunk + + // 1. 
response.created + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_rt"}, + }, state) + allChunks = append(allChunks, chunks...) + + // 2. text deltas + for _, text := range []string{"Hello", ", ", "world", "!"} { + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: text, + }, state) + allChunks = append(allChunks, chunks...) + } + + // 3. response.completed + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 4, + TotalTokens: 14, + }, + }, + }, state) + allChunks = append(allChunks, chunks...) + + // Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7 + require.Len(t, allChunks, 7) + + // First chunk has role + assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role) + + // Text chunks + var fullText string + for i := 1; i <= 4; i++ { + require.NotNil(t, allChunks[i].Choices[0].Delta.Content) + fullText += *allChunks[i].Choices[0].Delta.Content + } + assert.Equal(t, "Hello, world!", fullText) + + // Finish chunk + require.NotNil(t, allChunks[5].Choices[0].FinishReason) + assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason) + + // Usage chunk + require.NotNil(t, allChunks[6].Usage) + assert.Equal(t, 10, allChunks[6].Usage.PromptTokens) + assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens) + + // All chunks share the same ID + for _, c := range allChunks { + assert.Equal(t, "resp_rt", c.ID) + } +} diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go new file mode 100644 index 00000000..c4a9e773 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -0,0 +1,385 @@ +package apicompat + +import ( + "encoding/json" + "fmt" + 
"strings" +) + +// ChatCompletionsToResponses converts a Chat Completions request into a +// Responses API request. The upstream always streams, so Stream is forced to +// true. store is always false and reasoning.encrypted_content is always +// included so that the response translator has full context. +func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) { + input, err := convertChatMessagesToResponsesInput(req.Messages) + if err != nil { + return nil, err + } + + inputJSON, err := json.Marshal(input) + if err != nil { + return nil, err + } + + out := &ResponsesRequest{ + Model: req.Model, + Input: inputJSON, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: true, // upstream always streams + Include: []string{"reasoning.encrypted_content"}, + ServiceTier: req.ServiceTier, + } + + storeFalse := false + out.Store = &storeFalse + + // max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens + maxTokens := 0 + if req.MaxTokens != nil { + maxTokens = *req.MaxTokens + } + if req.MaxCompletionTokens != nil { + maxTokens = *req.MaxCompletionTokens + } + if maxTokens > 0 { + v := maxTokens + if v < minMaxOutputTokens { + v = minMaxOutputTokens + } + out.MaxOutputTokens = &v + } + + // reasoning_effort → reasoning.effort + reasoning.summary="auto" + if req.ReasoningEffort != "" { + out.Reasoning = &ResponsesReasoning{ + Effort: req.ReasoningEffort, + Summary: "auto", + } + } + + // tools[] and legacy functions[] → ResponsesTool[] + if len(req.Tools) > 0 || len(req.Functions) > 0 { + out.Tools = convertChatToolsToResponses(req.Tools, req.Functions) + } + + // tool_choice: already compatible format — pass through directly. + // Legacy function_call needs mapping. 
+ if len(req.ToolChoice) > 0 { + out.ToolChoice = req.ToolChoice + } else if len(req.FunctionCall) > 0 { + tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall) + if err != nil { + return nil, fmt.Errorf("convert function_call: %w", err) + } + out.ToolChoice = tc + } + + return out, nil +} + +// convertChatMessagesToResponsesInput converts the Chat Completions messages +// array into a Responses API input items array. +func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) { + var out []ResponsesInputItem + for _, m := range msgs { + items, err := chatMessageToResponsesItems(m) + if err != nil { + return nil, err + } + out = append(out, items...) + } + return out, nil +} + +// chatMessageToResponsesItems converts a single ChatMessage into one or more +// ResponsesInputItem values. +func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) { + switch m.Role { + case "system": + return chatSystemToResponses(m) + case "user": + return chatUserToResponses(m) + case "assistant": + return chatAssistantToResponses(m) + case "tool": + return chatToolToResponses(m) + case "function": + return chatFunctionToResponses(m) + default: + return chatUserToResponses(m) + } +} + +// chatSystemToResponses converts a system message. +func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + text, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + content, err := json.Marshal(text) + if err != nil { + return nil, err + } + return []ResponsesInputItem{{Role: "system", Content: content}}, nil +} + +// chatUserToResponses converts a user message, handling both plain strings and +// multi-modal content arrays. +func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + // Try plain string first. 
+ var s string + if err := json.Unmarshal(m.Content, &s); err == nil { + content, _ := json.Marshal(s) + return []ResponsesInputItem{{Role: "user", Content: content}}, nil + } + + var parts []ChatContentPart + if err := json.Unmarshal(m.Content, &parts); err != nil { + return nil, fmt.Errorf("parse user content: %w", err) + } + + var responseParts []ResponsesContentPart + for _, p := range parts { + switch p.Type { + case "text": + if p.Text != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_text", + Text: p.Text, + }) + } + case "image_url": + if p.ImageURL != nil && p.ImageURL.URL != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_image", + ImageURL: p.ImageURL.URL, + }) + } + } + } + + content, err := json.Marshal(responseParts) + if err != nil { + return nil, err + } + return []ResponsesInputItem{{Role: "user", Content: content}}, nil +} + +// chatAssistantToResponses converts an assistant message. If there is both +// text content and tool_calls, the text is emitted as an assistant message +// first, then each tool_call becomes a function_call item. If the content is +// empty/nil and there are tool_calls, only function_call items are emitted. +func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + var items []ResponsesInputItem + + // Emit assistant message with output_text if content is non-empty. + if len(m.Content) > 0 { + s, err := parseAssistantContent(m.Content) + if err != nil { + return nil, err + } + if s != "" { + parts := []ResponsesContentPart{{Type: "output_text", Text: s}} + partsJSON, err := json.Marshal(parts) + if err != nil { + return nil, err + } + items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON}) + } + } + + // Emit one function_call item per tool_call. 
+ for _, tc := range m.ToolCalls { + args := tc.Function.Arguments + if args == "" { + args = "{}" + } + items = append(items, ResponsesInputItem{ + Type: "function_call", + CallID: tc.ID, + Name: tc.Function.Name, + Arguments: args, + }) + } + + return items, nil +} + +// parseAssistantContent returns assistant content as plain text. +// +// Supported formats: +// - JSON string +// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}]) +// +// For structured thinking/reasoning parts, it preserves semantics by wrapping +// the text in explicit tags so downstream can still distinguish it from normal text. +func parseAssistantContent(raw json.RawMessage) (string, error) { + if len(raw) == 0 { + return "", nil + } + + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s, nil + } + + var parts []map[string]any + if err := json.Unmarshal(raw, &parts); err != nil { + // Keep compatibility with prior behavior: unsupported assistant content + // formats are ignored instead of failing the whole request conversion. + return "", nil + } + + var b strings.Builder + write := func(v string) error { + _, err := b.WriteString(v) + return err + } + for _, p := range parts { + typ, _ := p["type"].(string) + text, _ := p["text"].(string) + thinking, _ := p["thinking"].(string) + + switch typ { + case "thinking", "reasoning": + if thinking != "" { + if err := write(""); err != nil { + return "", err + } + if err := write(thinking); err != nil { + return "", err + } + if err := write(""); err != nil { + return "", err + } + } else if text != "" { + if err := write(""); err != nil { + return "", err + } + if err := write(text); err != nil { + return "", err + } + if err := write(""); err != nil { + return "", err + } + } + default: + if text != "" { + if err := write(text); err != nil { + return "", err + } + } + } + } + + return b.String(), nil +} + +// chatToolToResponses converts a tool result message (role=tool) into a +// function_call_output item. 
+func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + output, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + if output == "" { + output = "(empty)" + } + return []ResponsesInputItem{{ + Type: "function_call_output", + CallID: m.ToolCallID, + Output: output, + }}, nil +} + +// chatFunctionToResponses converts a legacy function result message +// (role=function) into a function_call_output item. The Name field is used as +// call_id since legacy function calls do not carry a separate call_id. +func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + output, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + if output == "" { + output = "(empty)" + } + return []ResponsesInputItem{{ + Type: "function_call_output", + CallID: m.Name, + Output: output, + }}, nil +} + +// parseChatContent returns the string value of a ChatMessage Content field. +// Content must be a JSON string. Returns "" if content is null or empty. +func parseChatContent(raw json.RawMessage) (string, error) { + if len(raw) == 0 { + return "", nil + } + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return "", fmt.Errorf("parse content as string: %w", err) + } + return s, nil +} + +// convertChatToolsToResponses maps Chat Completions tool definitions and legacy +// function definitions to Responses API tool definitions. +func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool { + var out []ResponsesTool + + for _, t := range tools { + if t.Type != "function" || t.Function == nil { + continue + } + rt := ResponsesTool{ + Type: "function", + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: t.Function.Parameters, + Strict: t.Function.Strict, + } + out = append(out, rt) + } + + // Legacy functions[] are treated as function-type tools. 
+ for _, f := range functions { + rt := ResponsesTool{ + Type: "function", + Name: f.Name, + Description: f.Description, + Parameters: f.Parameters, + Strict: f.Strict, + } + out = append(out, rt) + } + + return out +} + +// convertChatFunctionCallToToolChoice maps the legacy function_call field to a +// Responses API tool_choice value. +// +// "auto" → "auto" +// "none" → "none" +// {"name":"X"} → {"type":"function","function":{"name":"X"}} +func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) { + // Try string first ("auto", "none", etc.) — pass through as-is. + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return json.Marshal(s) + } + + // Object form: {"name":"X"} + var obj struct { + Name string `json:"name"` + } + if err := json.Unmarshal(raw, &obj); err != nil { + return nil, err + } + return json.Marshal(map[string]any{ + "type": "function", + "function": map[string]string{"name": obj.Name}, + }) +} diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go new file mode 100644 index 00000000..5409a0f4 --- /dev/null +++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go @@ -0,0 +1,516 @@ +package apicompat + +import ( + "encoding/json" + "fmt" + "time" +) + +// --------------------------------------------------------------------------- +// Non-streaming: ResponsesResponse → AnthropicResponse +// --------------------------------------------------------------------------- + +// ResponsesToAnthropic converts a Responses API response directly into an +// Anthropic Messages response. Reasoning output items are mapped to thinking +// blocks; function_call items become tool_use blocks. 
+func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicResponse { + out := &AnthropicResponse{ + ID: resp.ID, + Type: "message", + Role: "assistant", + Model: model, + } + + var blocks []AnthropicContentBlock + + for _, item := range resp.Output { + switch item.Type { + case "reasoning": + summaryText := "" + for _, s := range item.Summary { + if s.Type == "summary_text" && s.Text != "" { + summaryText += s.Text + } + } + if summaryText != "" { + blocks = append(blocks, AnthropicContentBlock{ + Type: "thinking", + Thinking: summaryText, + }) + } + case "message": + for _, part := range item.Content { + if part.Type == "output_text" && part.Text != "" { + blocks = append(blocks, AnthropicContentBlock{ + Type: "text", + Text: part.Text, + }) + } + } + case "function_call": + blocks = append(blocks, AnthropicContentBlock{ + Type: "tool_use", + ID: fromResponsesCallID(item.CallID), + Name: item.Name, + Input: json.RawMessage(item.Arguments), + }) + case "web_search_call": + toolUseID := "srvtoolu_" + item.ID + query := "" + if item.Action != nil { + query = item.Action.Query + } + inputJSON, _ := json.Marshal(map[string]string{"query": query}) + blocks = append(blocks, AnthropicContentBlock{ + Type: "server_tool_use", + ID: toolUseID, + Name: "web_search", + Input: inputJSON, + }) + emptyResults, _ := json.Marshal([]struct{}{}) + blocks = append(blocks, AnthropicContentBlock{ + Type: "web_search_tool_result", + ToolUseID: toolUseID, + Content: emptyResults, + }) + } + } + + if len(blocks) == 0 { + blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""}) + } + out.Content = blocks + + out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks) + + if resp.Usage != nil { + out.Usage = AnthropicUsage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + } + if resp.Usage.InputTokensDetails != nil { + out.Usage.CacheReadInputTokens = 
resp.Usage.InputTokensDetails.CachedTokens + } + } + + return out +} + +func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string { + switch status { + case "incomplete": + if details != nil && details.Reason == "max_output_tokens" { + return "max_tokens" + } + return "end_turn" + case "completed": + if len(blocks) > 0 && blocks[len(blocks)-1].Type == "tool_use" { + return "tool_use" + } + return "end_turn" + default: + return "end_turn" + } +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter) +// --------------------------------------------------------------------------- + +// ResponsesEventToAnthropicState tracks state for converting a sequence of +// Responses SSE events directly into Anthropic SSE events. +type ResponsesEventToAnthropicState struct { + MessageStartSent bool + MessageStopSent bool + + ContentBlockIndex int + ContentBlockOpen bool + CurrentBlockType string // "text" | "thinking" | "tool_use" + + // OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index. + OutputIndexToBlockIdx map[int]int + + InputTokens int + OutputTokens int + CacheReadInputTokens int + + ResponseID string + Model string + Created int64 +} + +// NewResponsesEventToAnthropicState returns an initialised stream state. +func NewResponsesEventToAnthropicState() *ResponsesEventToAnthropicState { + return &ResponsesEventToAnthropicState{ + OutputIndexToBlockIdx: make(map[int]int), + Created: time.Now().Unix(), + } +} + +// ResponsesEventToAnthropicEvents converts a single Responses SSE event into +// zero or more Anthropic SSE events, updating state as it goes. 
+func ResponsesEventToAnthropicEvents( + evt *ResponsesStreamEvent, + state *ResponsesEventToAnthropicState, +) []AnthropicStreamEvent { + switch evt.Type { + case "response.created": + return resToAnthHandleCreated(evt, state) + case "response.output_item.added": + return resToAnthHandleOutputItemAdded(evt, state) + case "response.output_text.delta": + return resToAnthHandleTextDelta(evt, state) + case "response.output_text.done": + return resToAnthHandleBlockDone(state) + case "response.function_call_arguments.delta": + return resToAnthHandleFuncArgsDelta(evt, state) + case "response.function_call_arguments.done": + return resToAnthHandleBlockDone(state) + case "response.output_item.done": + return resToAnthHandleOutputItemDone(evt, state) + case "response.reasoning_summary_text.delta": + return resToAnthHandleReasoningDelta(evt, state) + case "response.reasoning_summary_text.done": + return resToAnthHandleBlockDone(state) + case "response.completed", "response.incomplete", "response.failed": + return resToAnthHandleCompleted(evt, state) + default: + return nil + } +} + +// FinalizeResponsesAnthropicStream emits synthetic termination events if the +// stream ended without a proper completion event. +func FinalizeResponsesAnthropicStream(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if !state.MessageStartSent || state.MessageStopSent { + return nil + } + + var events []AnthropicStreamEvent + events = append(events, closeCurrentBlock(state)...) + + events = append(events, + AnthropicStreamEvent{ + Type: "message_delta", + Delta: &AnthropicDelta{ + StopReason: "end_turn", + }, + Usage: &AnthropicUsage{ + InputTokens: state.InputTokens, + OutputTokens: state.OutputTokens, + CacheReadInputTokens: state.CacheReadInputTokens, + }, + }, + AnthropicStreamEvent{Type: "message_stop"}, + ) + state.MessageStopSent = true + return events +} + +// ResponsesAnthropicEventToSSE formats an AnthropicStreamEvent as an SSE line pair. 
+func ResponsesAnthropicEventToSSE(evt AnthropicStreamEvent) (string, error) { + data, err := json.Marshal(evt) + if err != nil { + return "", err + } + return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil +} + +// --- internal handlers --- + +func resToAnthHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Response != nil { + state.ResponseID = evt.Response.ID + // Only use upstream model if no override was set (e.g. originalModel) + if state.Model == "" { + state.Model = evt.Response.Model + } + } + + if state.MessageStartSent { + return nil + } + state.MessageStartSent = true + + return []AnthropicStreamEvent{{ + Type: "message_start", + Message: &AnthropicResponse{ + ID: state.ResponseID, + Type: "message", + Role: "assistant", + Content: []AnthropicContentBlock{}, + Model: state.Model, + Usage: AnthropicUsage{ + InputTokens: 0, + OutputTokens: 0, + }, + }, + }} +} + +func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Item == nil { + return nil + } + + switch evt.Item.Type { + case "function_call": + var events []AnthropicStreamEvent + events = append(events, closeCurrentBlock(state)...) + + idx := state.ContentBlockIndex + state.OutputIndexToBlockIdx[evt.OutputIndex] = idx + state.ContentBlockOpen = true + state.CurrentBlockType = "tool_use" + + events = append(events, AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &AnthropicContentBlock{ + Type: "tool_use", + ID: fromResponsesCallID(evt.Item.CallID), + Name: evt.Item.Name, + Input: json.RawMessage("{}"), + }, + }) + return events + + case "reasoning": + var events []AnthropicStreamEvent + events = append(events, closeCurrentBlock(state)...) 
+ + idx := state.ContentBlockIndex + state.OutputIndexToBlockIdx[evt.OutputIndex] = idx + state.ContentBlockOpen = true + state.CurrentBlockType = "thinking" + + events = append(events, AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &AnthropicContentBlock{ + Type: "thinking", + Thinking: "", + }, + }) + return events + + case "message": + return nil + } + + return nil +} + +func resToAnthHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Delta == "" { + return nil + } + + var events []AnthropicStreamEvent + + if !state.ContentBlockOpen || state.CurrentBlockType != "text" { + events = append(events, closeCurrentBlock(state)...) + + idx := state.ContentBlockIndex + state.ContentBlockOpen = true + state.CurrentBlockType = "text" + + events = append(events, AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &AnthropicContentBlock{ + Type: "text", + Text: "", + }, + }) + } + + idx := state.ContentBlockIndex + events = append(events, AnthropicStreamEvent{ + Type: "content_block_delta", + Index: &idx, + Delta: &AnthropicDelta{ + Type: "text_delta", + Text: evt.Delta, + }, + }) + return events +} + +func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Delta == "" { + return nil + } + + blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex] + if !ok { + return nil + } + + return []AnthropicStreamEvent{{ + Type: "content_block_delta", + Index: &blockIdx, + Delta: &AnthropicDelta{ + Type: "input_json_delta", + PartialJSON: evt.Delta, + }, + }} +} + +func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Delta == "" { + return nil + } + + blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex] + if !ok { + return nil + } + + return []AnthropicStreamEvent{{ + Type: 
"content_block_delta", + Index: &blockIdx, + Delta: &AnthropicDelta{ + Type: "thinking_delta", + Thinking: evt.Delta, + }, + }} +} + +func resToAnthHandleBlockDone(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if !state.ContentBlockOpen { + return nil + } + return closeCurrentBlock(state) +} + +func resToAnthHandleOutputItemDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Item == nil { + return nil + } + + // Handle web_search_call → synthesize server_tool_use + web_search_tool_result blocks. + if evt.Item.Type == "web_search_call" && evt.Item.Status == "completed" { + return resToAnthHandleWebSearchDone(evt, state) + } + + if state.ContentBlockOpen { + return closeCurrentBlock(state) + } + return nil +} + +// resToAnthHandleWebSearchDone converts an OpenAI web_search_call output item +// into Anthropic server_tool_use + web_search_tool_result content block pairs. +// This allows Claude Code to count the searches performed. +func resToAnthHandleWebSearchDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + var events []AnthropicStreamEvent + events = append(events, closeCurrentBlock(state)...) + + toolUseID := "srvtoolu_" + evt.Item.ID + query := "" + if evt.Item.Action != nil { + query = evt.Item.Action.Query + } + inputJSON, _ := json.Marshal(map[string]string{"query": query}) + + // Emit server_tool_use block (start + stop). + idx1 := state.ContentBlockIndex + events = append(events, AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx1, + ContentBlock: &AnthropicContentBlock{ + Type: "server_tool_use", + ID: toolUseID, + Name: "web_search", + Input: inputJSON, + }, + }) + events = append(events, AnthropicStreamEvent{ + Type: "content_block_stop", + Index: &idx1, + }) + state.ContentBlockIndex++ + + // Emit web_search_tool_result block (start + stop). 
+ // Content is empty because OpenAI does not expose individual search results; + // the model consumes them internally and produces text output. + emptyResults, _ := json.Marshal([]struct{}{}) + idx2 := state.ContentBlockIndex + events = append(events, AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx2, + ContentBlock: &AnthropicContentBlock{ + Type: "web_search_tool_result", + ToolUseID: toolUseID, + Content: emptyResults, + }, + }) + events = append(events, AnthropicStreamEvent{ + Type: "content_block_stop", + Index: &idx2, + }) + state.ContentBlockIndex++ + + return events +} + +func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if state.MessageStopSent { + return nil + } + + var events []AnthropicStreamEvent + events = append(events, closeCurrentBlock(state)...) + + stopReason := "end_turn" + if evt.Response != nil { + if evt.Response.Usage != nil { + state.InputTokens = evt.Response.Usage.InputTokens + state.OutputTokens = evt.Response.Usage.OutputTokens + if evt.Response.Usage.InputTokensDetails != nil { + state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens + } + } + switch evt.Response.Status { + case "incomplete": + if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" { + stopReason = "max_tokens" + } + case "completed": + if state.ContentBlockIndex > 0 && state.CurrentBlockType == "tool_use" { + stopReason = "tool_use" + } + } + } + + events = append(events, + AnthropicStreamEvent{ + Type: "message_delta", + Delta: &AnthropicDelta{ + StopReason: stopReason, + }, + Usage: &AnthropicUsage{ + InputTokens: state.InputTokens, + OutputTokens: state.OutputTokens, + CacheReadInputTokens: state.CacheReadInputTokens, + }, + }, + AnthropicStreamEvent{Type: "message_stop"}, + ) + state.MessageStopSent = true + return events +} + +func closeCurrentBlock(state *ResponsesEventToAnthropicState) 
[]AnthropicStreamEvent { + if !state.ContentBlockOpen { + return nil + } + idx := state.ContentBlockIndex + state.ContentBlockOpen = false + state.ContentBlockIndex++ + return []AnthropicStreamEvent{{ + Type: "content_block_stop", + Index: &idx, + }} +} diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go new file mode 100644 index 00000000..688a68eb --- /dev/null +++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go @@ -0,0 +1,374 @@ +package apicompat + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "time" +) + +// --------------------------------------------------------------------------- +// Non-streaming: ResponsesResponse → ChatCompletionsResponse +// --------------------------------------------------------------------------- + +// ResponsesToChatCompletions converts a Responses API response into a Chat +// Completions response. Text output items are concatenated into +// choices[0].message.content; function_call items become tool_calls. 
+func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse { + id := resp.ID + if id == "" { + id = generateChatCmplID() + } + + out := &ChatCompletionsResponse{ + ID: id, + Object: "chat.completion", + Created: time.Now().Unix(), + Model: model, + } + + var contentText string + var reasoningText string + var toolCalls []ChatToolCall + + for _, item := range resp.Output { + switch item.Type { + case "message": + for _, part := range item.Content { + if part.Type == "output_text" && part.Text != "" { + contentText += part.Text + } + } + case "function_call": + toolCalls = append(toolCalls, ChatToolCall{ + ID: item.CallID, + Type: "function", + Function: ChatFunctionCall{ + Name: item.Name, + Arguments: item.Arguments, + }, + }) + case "reasoning": + for _, s := range item.Summary { + if s.Type == "summary_text" && s.Text != "" { + reasoningText += s.Text + } + } + case "web_search_call": + // silently consumed — results already incorporated into text output + } + } + + msg := ChatMessage{Role: "assistant"} + if len(toolCalls) > 0 { + msg.ToolCalls = toolCalls + } + if contentText != "" { + raw, _ := json.Marshal(contentText) + msg.Content = raw + } + if reasoningText != "" { + msg.ReasoningContent = reasoningText + } + + finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls) + + out.Choices = []ChatChoice{{ + Index: 0, + Message: msg, + FinishReason: finishReason, + }} + + if resp.Usage != nil { + usage := &ChatUsage{ + PromptTokens: resp.Usage.InputTokens, + CompletionTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, + } + if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails = &ChatTokenDetails{ + CachedTokens: resp.Usage.InputTokensDetails.CachedTokens, + } + } + out.Usage = usage + } + + return out +} + +func responsesStatusToChatFinishReason(status string, details 
*ResponsesIncompleteDetails, toolCalls []ChatToolCall) string { + switch status { + case "incomplete": + if details != nil && details.Reason == "max_output_tokens" { + return "length" + } + return "stop" + case "completed": + if len(toolCalls) > 0 { + return "tool_calls" + } + return "stop" + default: + return "stop" + } +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter) +// --------------------------------------------------------------------------- + +// ResponsesEventToChatState tracks state for converting a sequence of Responses +// SSE events into Chat Completions SSE chunks. +type ResponsesEventToChatState struct { + ID string + Model string + Created int64 + SentRole bool + SawToolCall bool + SawText bool + Finalized bool // true after finish chunk has been emitted + NextToolCallIndex int // next sequential tool_call index to assign + OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index + IncludeUsage bool + Usage *ChatUsage +} + +// NewResponsesEventToChatState returns an initialised stream state. +func NewResponsesEventToChatState() *ResponsesEventToChatState { + return &ResponsesEventToChatState{ + ID: generateChatCmplID(), + Created: time.Now().Unix(), + OutputIndexToToolIndex: make(map[int]int), + } +} + +// ResponsesEventToChatChunks converts a single Responses SSE event into zero +// or more Chat Completions chunks, updating state as it goes. 
+func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + switch evt.Type { + case "response.created": + return resToChatHandleCreated(evt, state) + case "response.output_text.delta": + return resToChatHandleTextDelta(evt, state) + case "response.output_item.added": + return resToChatHandleOutputItemAdded(evt, state) + case "response.function_call_arguments.delta": + return resToChatHandleFuncArgsDelta(evt, state) + case "response.reasoning_summary_text.delta": + return resToChatHandleReasoningDelta(evt, state) + case "response.reasoning_summary_text.done": + return nil + case "response.completed", "response.incomplete", "response.failed": + return resToChatHandleCompleted(evt, state) + default: + return nil + } +} + +// FinalizeResponsesChatStream emits a final chunk with finish_reason if the +// stream ended without a proper completion event (e.g. upstream disconnect). +// It is idempotent: if a completion event already emitted the finish chunk, +// this returns nil. +func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk { + if state.Finalized { + return nil + } + state.Finalized = true + + finishReason := "stop" + if state.SawToolCall { + finishReason = "tool_calls" + } + + chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)} + + if state.IncludeUsage && state.Usage != nil { + chunks = append(chunks, ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{}, + Usage: state.Usage, + }) + } + + return chunks +} + +// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line. 
+func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) { + data, err := json.Marshal(chunk) + if err != nil { + return "", err + } + return fmt.Sprintf("data: %s\n\n", data), nil +} + +// --- internal handlers --- + +func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Response != nil { + if evt.Response.ID != "" { + state.ID = evt.Response.ID + } + if state.Model == "" && evt.Response.Model != "" { + state.Model = evt.Response.Model + } + } + // Emit the role chunk. + if state.SentRole { + return nil + } + state.SentRole = true + + role := "assistant" + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})} +} + +func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + state.SawText = true + content := evt.Delta + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})} +} + +func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Item == nil || evt.Item.Type != "function_call" { + return nil + } + + state.SawToolCall = true + idx := state.NextToolCallIndex + state.OutputIndexToToolIndex[evt.OutputIndex] = idx + state.NextToolCallIndex++ + + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ + ToolCalls: []ChatToolCall{{ + Index: &idx, + ID: evt.Item.CallID, + Type: "function", + Function: ChatFunctionCall{ + Name: evt.Item.Name, + }, + }}, + })} +} + +func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + + idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex] + if !ok { + return nil + } + + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ + ToolCalls: []ChatToolCall{{ + Index: &idx, + Function: ChatFunctionCall{ + 
Arguments: evt.Delta, + }, + }}, + })} +} + +func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + reasoning := evt.Delta + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})} +} + +func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + state.Finalized = true + finishReason := "stop" + + if evt.Response != nil { + if evt.Response.Usage != nil { + u := evt.Response.Usage + usage := &ChatUsage{ + PromptTokens: u.InputTokens, + CompletionTokens: u.OutputTokens, + TotalTokens: u.InputTokens + u.OutputTokens, + } + if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails = &ChatTokenDetails{ + CachedTokens: u.InputTokensDetails.CachedTokens, + } + } + state.Usage = usage + } + + switch evt.Response.Status { + case "incomplete": + if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" { + finishReason = "length" + } + case "completed": + if state.SawToolCall { + finishReason = "tool_calls" + } + } + } else if state.SawToolCall { + finishReason = "tool_calls" + } + + var chunks []ChatCompletionsChunk + chunks = append(chunks, makeChatFinishChunk(state, finishReason)) + + if state.IncludeUsage && state.Usage != nil { + chunks = append(chunks, ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{}, + Usage: state.Usage, + }) + } + + return chunks +} + +func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk { + return ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: delta, + FinishReason: nil, + }}, + } +} + 
+func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk { + empty := "" + return ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: ChatDelta{Content: &empty}, + FinishReason: &finishReason, + }}, + } +} + +// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID. +func generateChatCmplID() string { + b := make([]byte, 12) + _, _ = rand.Read(b) + return "chatcmpl-" + hex.EncodeToString(b) +} diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go new file mode 100644 index 00000000..b724a5ed --- /dev/null +++ b/backend/internal/pkg/apicompat/types.go @@ -0,0 +1,482 @@ +// Package apicompat provides type definitions and conversion utilities for +// translating between Anthropic Messages and OpenAI Responses API formats. +// It enables multi-protocol support so that clients using different API +// formats can be served through a unified gateway. +package apicompat + +import "encoding/json" + +// --------------------------------------------------------------------------- +// Anthropic Messages API types +// --------------------------------------------------------------------------- + +// AnthropicRequest is the request body for POST /v1/messages. 
+type AnthropicRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock + Messages []AnthropicMessage `json:"messages"` + Tools []AnthropicTool `json:"tools,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + StopSeqs []string `json:"stop_sequences,omitempty"` + Thinking *AnthropicThinking `json:"thinking,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"` +} + +// AnthropicOutputConfig controls output generation parameters. +type AnthropicOutputConfig struct { + Effort string `json:"effort,omitempty"` // "low" | "medium" | "high" +} + +// AnthropicThinking configures extended thinking in the Anthropic API. +type AnthropicThinking struct { + Type string `json:"type"` // "enabled" | "adaptive" | "disabled" + BudgetTokens int `json:"budget_tokens,omitempty"` // max thinking tokens +} + +// AnthropicMessage is a single message in the Anthropic conversation. +type AnthropicMessage struct { + Role string `json:"role"` // "user" | "assistant" + Content json.RawMessage `json:"content"` +} + +// AnthropicContentBlock is one block inside a message's content array. 
+type AnthropicContentBlock struct { + Type string `json:"type"` + + // type=text + Text string `json:"text,omitempty"` + + // type=thinking + Thinking string `json:"thinking,omitempty"` + + // type=image + Source *AnthropicImageSource `json:"source,omitempty"` + + // type=tool_use + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input json.RawMessage `json:"input,omitempty"` + + // type=tool_result + ToolUseID string `json:"tool_use_id,omitempty"` + Content json.RawMessage `json:"content,omitempty"` // string or []AnthropicContentBlock + IsError bool `json:"is_error,omitempty"` +} + +// AnthropicImageSource describes the source data for an image content block. +type AnthropicImageSource struct { + Type string `json:"type"` // "base64" + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +// AnthropicTool describes a tool available to the model. +type AnthropicTool struct { + Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object +} + +// AnthropicResponse is the non-streaming response from POST /v1/messages. +type AnthropicResponse struct { + ID string `json:"id"` + Type string `json:"type"` // "message" + Role string `json:"role"` // "assistant" + Content []AnthropicContentBlock `json:"content"` + Model string `json:"model"` + StopReason string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence,omitempty"` + Usage AnthropicUsage `json:"usage"` +} + +// AnthropicUsage holds token counts in Anthropic format. 
+type AnthropicUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` +} + +// --------------------------------------------------------------------------- +// Anthropic SSE event types +// --------------------------------------------------------------------------- + +// AnthropicStreamEvent is a single SSE event in the Anthropic streaming protocol. +type AnthropicStreamEvent struct { + Type string `json:"type"` + + // message_start + Message *AnthropicResponse `json:"message,omitempty"` + + // content_block_start + Index *int `json:"index,omitempty"` + ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"` + + // content_block_delta + Delta *AnthropicDelta `json:"delta,omitempty"` + + // message_delta + Usage *AnthropicUsage `json:"usage,omitempty"` +} + +// AnthropicDelta carries incremental content in streaming events. +type AnthropicDelta struct { + Type string `json:"type,omitempty"` // "text_delta" | "input_json_delta" | "thinking_delta" | "signature_delta" + + // text_delta + Text string `json:"text,omitempty"` + + // input_json_delta + PartialJSON string `json:"partial_json,omitempty"` + + // thinking_delta + Thinking string `json:"thinking,omitempty"` + + // signature_delta + Signature string `json:"signature,omitempty"` + + // message_delta fields + StopReason string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` +} + +// --------------------------------------------------------------------------- +// OpenAI Responses API types +// --------------------------------------------------------------------------- + +// ResponsesRequest is the request body for POST /v1/responses. 
+type ResponsesRequest struct { + Model string `json:"model"` + Input json.RawMessage `json:"input"` // string or []ResponsesInputItem + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools []ResponsesTool `json:"tools,omitempty"` + Include []string `json:"include,omitempty"` + Store *bool `json:"store,omitempty"` + Reasoning *ResponsesReasoning `json:"reasoning,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` +} + +// ResponsesReasoning configures reasoning effort in the Responses API. +type ResponsesReasoning struct { + Effort string `json:"effort"` // "low" | "medium" | "high" + Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed" +} + +// ResponsesInputItem is one item in the Responses API input array. +// The Type field determines which other fields are populated. +type ResponsesInputItem struct { + // Common + Type string `json:"type,omitempty"` // "" for role-based messages + + // Role-based messages (system/user/assistant) + Role string `json:"role,omitempty"` + Content json.RawMessage `json:"content,omitempty"` // string or []ResponsesContentPart + + // type=function_call + CallID string `json:"call_id,omitempty"` + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + ID string `json:"id,omitempty"` + + // type=function_call_output + Output string `json:"output,omitempty"` +} + +// ResponsesContentPart is a typed content part in a Responses message. +type ResponsesContentPart struct { + Type string `json:"type"` // "input_text" | "output_text" | "input_image" + Text string `json:"text,omitempty"` + ImageURL string `json:"image_url,omitempty"` // data URI for input_image +} + +// ResponsesTool describes a tool in the Responses API. 
+type ResponsesTool struct { + Type string `json:"type"` // "function" | "web_search" | "local_shell" etc. + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +// ResponsesResponse is the non-streaming response from POST /v1/responses. +type ResponsesResponse struct { + ID string `json:"id"` + Object string `json:"object"` // "response" + Model string `json:"model"` + Status string `json:"status"` // "completed" | "incomplete" | "failed" + Output []ResponsesOutput `json:"output"` + Usage *ResponsesUsage `json:"usage,omitempty"` + + // incomplete_details is present when status="incomplete" + IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details,omitempty"` + + // Error is present when status="failed" + Error *ResponsesError `json:"error,omitempty"` +} + +// ResponsesError describes an error in a failed response. +type ResponsesError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// ResponsesIncompleteDetails explains why a response is incomplete. +type ResponsesIncompleteDetails struct { + Reason string `json:"reason"` // "max_output_tokens" | "content_filter" +} + +// ResponsesOutput is one output item in a Responses API response. 
+type ResponsesOutput struct { + Type string `json:"type"` // "message" | "reasoning" | "function_call" | "web_search_call" + + // type=message + ID string `json:"id,omitempty"` + Role string `json:"role,omitempty"` + Content []ResponsesContentPart `json:"content,omitempty"` + Status string `json:"status,omitempty"` + + // type=reasoning + EncryptedContent string `json:"encrypted_content,omitempty"` + Summary []ResponsesSummary `json:"summary,omitempty"` + + // type=function_call + CallID string `json:"call_id,omitempty"` + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + + // type=web_search_call + Action *WebSearchAction `json:"action,omitempty"` +} + +// WebSearchAction describes the search action in a web_search_call output item. +type WebSearchAction struct { + Type string `json:"type,omitempty"` // "search" + Query string `json:"query,omitempty"` // primary search query +} + +// ResponsesSummary is a summary text block inside a reasoning output. +type ResponsesSummary struct { + Type string `json:"type"` // "summary_text" + Text string `json:"text"` +} + +// ResponsesUsage holds token counts in Responses API format. +type ResponsesUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` + + // Optional detailed breakdown + InputTokensDetails *ResponsesInputTokensDetails `json:"input_tokens_details,omitempty"` + OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"` +} + +// ResponsesInputTokensDetails breaks down input token usage. +type ResponsesInputTokensDetails struct { + CachedTokens int `json:"cached_tokens,omitempty"` +} + +// ResponsesOutputTokensDetails breaks down output token usage. 
+type ResponsesOutputTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens,omitempty"` +} + +// --------------------------------------------------------------------------- +// Responses SSE event types +// --------------------------------------------------------------------------- + +// ResponsesStreamEvent is a single SSE event in the Responses streaming protocol. +// The Type field corresponds to the "type" in the JSON payload. +type ResponsesStreamEvent struct { + Type string `json:"type"` + + // response.created / response.completed / response.failed / response.incomplete + Response *ResponsesResponse `json:"response,omitempty"` + + // response.output_item.added / response.output_item.done + Item *ResponsesOutput `json:"item,omitempty"` + + // response.output_text.delta / response.output_text.done + OutputIndex int `json:"output_index,omitempty"` + ContentIndex int `json:"content_index,omitempty"` + Delta string `json:"delta,omitempty"` + Text string `json:"text,omitempty"` + ItemID string `json:"item_id,omitempty"` + + // response.function_call_arguments.delta / done + CallID string `json:"call_id,omitempty"` + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + + // response.reasoning_summary_text.delta / done + // Reuses Text/Delta fields above, SummaryIndex identifies which summary part + SummaryIndex int `json:"summary_index,omitempty"` + + // error event fields + Code string `json:"code,omitempty"` + Param string `json:"param,omitempty"` + + // Sequence number for ordering events + SequenceNumber int `json:"sequence_number,omitempty"` +} + +// --------------------------------------------------------------------------- +// OpenAI Chat Completions API types +// --------------------------------------------------------------------------- + +// ChatCompletionsRequest is the request body for POST /v1/chat/completions. 
+type ChatCompletionsRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + MaxTokens *int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"` + Tools []ChatTool `json:"tools,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" + ServiceTier string `json:"service_tier,omitempty"` + Stop json.RawMessage `json:"stop,omitempty"` // string or []string + + // Legacy function calling (deprecated but still supported) + Functions []ChatFunction `json:"functions,omitempty"` + FunctionCall json.RawMessage `json:"function_call,omitempty"` +} + +// ChatStreamOptions configures streaming behavior. +type ChatStreamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` +} + +// ChatMessage is a single message in the Chat Completions conversation. +type ChatMessage struct { + Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function" + Content json.RawMessage `json:"content,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + Name string `json:"name,omitempty"` + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + + // Legacy function calling + FunctionCall *ChatFunctionCall `json:"function_call,omitempty"` +} + +// ChatContentPart is a typed content part in a multi-modal message. +type ChatContentPart struct { + Type string `json:"type"` // "text" | "image_url" + Text string `json:"text,omitempty"` + ImageURL *ChatImageURL `json:"image_url,omitempty"` +} + +// ChatImageURL contains the URL for an image content part. 
+type ChatImageURL struct { + URL string `json:"url"` + Detail string `json:"detail,omitempty"` // "auto" | "low" | "high" +} + +// ChatTool describes a tool available to the model. +type ChatTool struct { + Type string `json:"type"` // "function" + Function *ChatFunction `json:"function,omitempty"` +} + +// ChatFunction describes a function tool definition. +type ChatFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +// ChatToolCall represents a tool call made by the assistant. +// Index is only populated in streaming chunks (omitted in non-streaming responses). +type ChatToolCall struct { + Index *int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` // "function" + Function ChatFunctionCall `json:"function"` +} + +// ChatFunctionCall contains the function name and arguments. +type ChatFunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions. +type ChatCompletionsResponse struct { + ID string `json:"id"` + Object string `json:"object"` // "chat.completion" + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChoice `json:"choices"` + Usage *ChatUsage `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` +} + +// ChatChoice is a single completion choice. +type ChatChoice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter" +} + +// ChatUsage holds token counts in Chat Completions format. 
+type ChatUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"` +} + +// ChatTokenDetails provides a breakdown of token usage. +type ChatTokenDetails struct { + CachedTokens int `json:"cached_tokens,omitempty"` +} + +// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions. +type ChatCompletionsChunk struct { + ID string `json:"id"` + Object string `json:"object"` // "chat.completion.chunk" + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChunkChoice `json:"choices"` + Usage *ChatUsage `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` +} + +// ChatChunkChoice is a single choice in a streaming chunk. +type ChatChunkChoice struct { + Index int `json:"index"` + Delta ChatDelta `json:"delta"` + FinishReason *string `json:"finish_reason"` // pointer: null when not final +} + +// ChatDelta carries incremental content in a streaming chunk. +type ChatDelta struct { + Role string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters + ReasoningContent *string `json:"reasoning_content,omitempty"` + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` +} + +// --------------------------------------------------------------------------- +// Shared constants +// --------------------------------------------------------------------------- + +// minMaxOutputTokens is the floor for max_output_tokens in a Responses request. +// Very small values may cause upstream API errors, so we enforce a minimum. 
+const minMaxOutputTokens = 128 diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index 22405382..dfca252f 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -16,7 +16,7 @@ const ( // DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。 // 这些 token 是客户端特有的,不应透传给上游 API。 -var DroppedBetas = []string{BetaContext1M, BetaFastMode} +var DroppedBetas = []string{} // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index c300b17d..882d2ebd 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -18,10 +18,12 @@ func DefaultModels() []Model { return []Model{ {Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods}, {Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods}, + {Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods}, {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods}, } } diff --git a/backend/internal/pkg/gemini/models_test.go b/backend/internal/pkg/gemini/models_test.go new file mode 100644 index 00000000..b80047fb --- /dev/null +++ b/backend/internal/pkg/gemini/models_test.go @@ -0,0 +1,28 @@ +package gemini + +import "testing" + +func TestDefaultModels_ContainsImageModels(t *testing.T) { + t.Parallel() + + models := DefaultModels() + byName := make(map[string]Model, len(models)) + for _, model := range models { + 
byName[model.Name] = model + } + + required := []string{ + "models/gemini-2.5-flash-image", + "models/gemini-3.1-flash-image", + } + + for _, name := range required { + model, ok := byName[name] + if !ok { + t.Fatalf("expected fallback model %q to exist", name) + } + if len(model.SupportedGenerationMethods) == 0 { + t.Fatalf("expected fallback model %q to advertise generation methods", name) + } + } +} diff --git a/backend/internal/pkg/geminicli/constants.go b/backend/internal/pkg/geminicli/constants.go index f5ee5735..97234ffd 100644 --- a/backend/internal/pkg/geminicli/constants.go +++ b/backend/internal/pkg/geminicli/constants.go @@ -39,7 +39,7 @@ const ( // They enable the "login without creating your own OAuth client" experience, but Google may // restrict which scopes are allowed for this client. GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - GeminiCLIOAuthClientSecret = "GOCSPX-your-client-secret" + GeminiCLIOAuthClientSecret = "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" // GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret. 
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET" diff --git a/backend/internal/pkg/geminicli/models.go b/backend/internal/pkg/geminicli/models.go index 1fc4d983..195fb06f 100644 --- a/backend/internal/pkg/geminicli/models.go +++ b/backend/internal/pkg/geminicli/models.go @@ -13,10 +13,12 @@ type Model struct { var DefaultModels = []Model{ {ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""}, {ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""}, + {ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""}, {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""}, {ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""}, {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, {ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""}, + {ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""}, } // DefaultTestModel is the default model to preselect in test flows. 
diff --git a/backend/internal/pkg/geminicli/models_test.go b/backend/internal/pkg/geminicli/models_test.go new file mode 100644 index 00000000..c1884e2e --- /dev/null +++ b/backend/internal/pkg/geminicli/models_test.go @@ -0,0 +1,23 @@ +package geminicli + +import "testing" + +func TestDefaultModels_ContainsImageModels(t *testing.T) { + t.Parallel() + + byID := make(map[string]Model, len(DefaultModels)) + for _, model := range DefaultModels { + byID[model.ID] = model + } + + required := []string{ + "gemini-2.5-flash-image", + "gemini-3.1-flash-image", + } + + for _, id := range required { + if _, ok := byID[id]; !ok { + t.Fatalf("expected curated Gemini model %q to exist", id) + } + } +} diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 6ef3d714..32e4bc5b 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -18,11 +18,11 @@ package httpclient import ( "fmt" "net/http" - "net/url" "strings" "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) @@ -41,7 +41,6 @@ type Options struct { Timeout time.Duration // 请求总超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间 InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true) - ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding) AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用) @@ -120,15 +119,13 @@ func buildTransport(opts Options) (*http.Transport, error) { return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead") } - proxyURL := strings.TrimSpace(opts.ProxyURL) - if proxyURL == "" { - return transport, nil - } - - parsed, err := url.Parse(proxyURL) + _, parsed, err := proxyurl.Parse(opts.ProxyURL) if err != nil { return nil, err } + if parsed == nil { + return transport, nil + } if err := 
proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { return nil, err @@ -138,12 +135,11 @@ func buildTransport(opts Options) (*http.Transport, error) { } func buildClientKey(opts Options) string { - return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d", + return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%d|%d|%d", strings.TrimSpace(opts.ProxyURL), opts.Timeout.String(), opts.ResponseHeaderTimeout.String(), opts.InsecureSkipVerify, - opts.ProxyStrict, opts.ValidateResolvedIP, opts.AllowPrivateHosts, opts.MaxIdleConns, diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index 4bbc68e7..b0a31a5f 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -15,6 +15,7 @@ type Model struct { // DefaultModels OpenAI models list var DefaultModels = []Model{ + {ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index 8bdcbe16..a35a5ea6 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -268,6 +268,7 @@ type IDTokenClaims struct { type OpenAIAuthClaims struct { ChatGPTAccountID string `json:"chatgpt_account_id"` ChatGPTUserID string `json:"chatgpt_user_id"` + ChatGPTPlanType string `json:"chatgpt_plan_type"` UserID string `json:"user_id"` Organizations []OrganizationClaim `json:"organizations"` } @@ -325,12 +326,9 @@ func (r *RefreshTokenRequest) ToFormData() string { return params.Encode() } -// ParseIDToken parses the ID Token 
JWT and extracts claims. -// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。 -// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名: -// -// https://auth.openai.com/.well-known/jwks.json -func ParseIDToken(idToken string) (*IDTokenClaims, error) { +// DecodeIDToken decodes the ID Token JWT payload without validating expiration. +// Use this for best-effort extraction (e.g., during data import) where the token may be expired. +func DecodeIDToken(idToken string) (*IDTokenClaims, error) { parts := strings.Split(idToken, ".") if len(parts) != 3 { return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) @@ -360,6 +358,20 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) { return nil, fmt.Errorf("failed to parse JWT claims: %w", err) } + return &claims, nil +} + +// ParseIDToken parses the ID Token JWT and extracts claims. +// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。 +// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名: +// +// https://auth.openai.com/.well-known/jwks.json +func ParseIDToken(idToken string) (*IDTokenClaims, error) { + claims, err := DecodeIDToken(idToken) + if err != nil { + return nil, err + } + // 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌) const clockSkewTolerance = 120 // 秒 now := time.Now().Unix() @@ -367,7 +379,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) { return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance) } - return &claims, nil + return claims, nil } // UserInfo represents user information extracted from ID Token claims. 
@@ -375,6 +387,7 @@ type UserInfo struct { Email string ChatGPTAccountID string ChatGPTUserID string + PlanType string UserID string OrganizationID string Organizations []OrganizationClaim @@ -389,6 +402,7 @@ func (c *IDTokenClaims) GetUserInfo() *UserInfo { if c.OpenAIAuth != nil { info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID + info.PlanType = c.OpenAIAuth.ChatGPTPlanType info.UserID = c.OpenAIAuth.UserID info.Organizations = c.OpenAIAuth.Organizations diff --git a/backend/internal/pkg/openai/request.go b/backend/internal/pkg/openai/request.go index c24d1273..dd8fe566 100644 --- a/backend/internal/pkg/openai/request.go +++ b/backend/internal/pkg/openai/request.go @@ -58,6 +58,12 @@ func IsCodexOfficialClientOriginator(originator string) bool { return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes) } +// IsCodexOfficialClientByHeaders checks whether the request headers indicate an +// official Codex client family request. 
+func IsCodexOfficialClientByHeaders(userAgent, originator string) bool { + return IsCodexOfficialClientRequest(userAgent) || IsCodexOfficialClientOriginator(originator) +} + func normalizeCodexClientHeader(value string) string { return strings.ToLower(strings.TrimSpace(value)) } diff --git a/backend/internal/pkg/openai/request_test.go b/backend/internal/pkg/openai/request_test.go index 508bf561..b4562a07 100644 --- a/backend/internal/pkg/openai/request_test.go +++ b/backend/internal/pkg/openai/request_test.go @@ -85,3 +85,26 @@ func TestIsCodexOfficialClientOriginator(t *testing.T) { }) } } + +func TestIsCodexOfficialClientByHeaders(t *testing.T) { + tests := []struct { + name string + ua string + originator string + want bool + }{ + {name: "仅 originator 命中 desktop", originator: "Codex Desktop", want: true}, + {name: "仅 originator 命中 vscode", originator: "codex_vscode", want: true}, + {name: "仅 ua 命中 desktop", ua: "Codex Desktop/1.2.3", want: true}, + {name: "ua 与 originator 都未命中", ua: "curl/8.0.1", originator: "my_client", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexOfficialClientByHeaders(tt.ua, tt.originator) + if got != tt.want { + t.Fatalf("IsCodexOfficialClientByHeaders(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want) + } + }) + } +} diff --git a/backend/internal/pkg/proxyurl/parse.go b/backend/internal/pkg/proxyurl/parse.go new file mode 100644 index 00000000..217556f2 --- /dev/null +++ b/backend/internal/pkg/proxyurl/parse.go @@ -0,0 +1,66 @@ +// Package proxyurl 提供代理 URL 的统一验证(fail-fast,无效代理不回退直连) +// +// 所有需要解析代理 URL 的地方必须通过此包的 Parse 函数。 +// 直接使用 url.Parse 处理代理 URL 是被禁止的。 +// 这确保了 fail-fast 行为:无效代理配置在创建时立即失败, +// 而不是在运行时静默回退到直连(产生 IP 关联风险)。 +package proxyurl + +import ( + "fmt" + "net/url" + "strings" +) + +// allowedSchemes 代理协议白名单 +var allowedSchemes = map[string]bool{ + "http": true, + "https": true, + "socks5": true, + "socks5h": true, +} + +// Parse 解析并验证代理 URL。 +// +// 语义: +// - 
空字符串 → ("", nil, nil),表示直连 +// - 非空且有效 → (trimmed, *url.URL, nil) +// - 非空但无效 → ("", nil, error),fail-fast 不回退 +// +// 验证规则: +// - TrimSpace 后为空视为直连 +// - url.Parse 失败返回 error(不含原始 URL,防凭据泄露) +// - Host 为空返回 error(用 Redacted() 脱敏) +// - Scheme 必须为 http/https/socks5/socks5h +// - socks5:// 自动升级为 socks5h://(确保 DNS 由代理端解析,防止 DNS 泄漏) +func Parse(raw string) (trimmed string, parsed *url.URL, err error) { + trimmed = strings.TrimSpace(raw) + if trimmed == "" { + return "", nil, nil + } + + parsed, err = url.Parse(trimmed) + if err != nil { + // 不使用 %w 包装,避免 url.Parse 的底层错误消息泄漏原始 URL(可能含凭据) + return "", nil, fmt.Errorf("invalid proxy URL: %v", err) + } + + if parsed.Host == "" || parsed.Hostname() == "" { + return "", nil, fmt.Errorf("proxy URL missing host: %s", parsed.Redacted()) + } + + scheme := strings.ToLower(parsed.Scheme) + if !allowedSchemes[scheme] { + return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, https, socks5, socks5h)", scheme) + } + + // 自动升级 socks5 → socks5h,确保 DNS 由代理端解析,防止 DNS 泄漏。 + // Go 的 golang.org/x/net/proxy 对 socks5:// 默认在客户端本地解析 DNS, + // 仅 socks5h:// 才将域名发送给代理端做远程 DNS 解析。 + if scheme == "socks5" { + parsed.Scheme = "socks5h" + trimmed = parsed.String() + } + + return trimmed, parsed, nil +} diff --git a/backend/internal/pkg/proxyurl/parse_test.go b/backend/internal/pkg/proxyurl/parse_test.go new file mode 100644 index 00000000..5fb57c16 --- /dev/null +++ b/backend/internal/pkg/proxyurl/parse_test.go @@ -0,0 +1,215 @@ +package proxyurl + +import ( + "strings" + "testing" +) + +func TestParse_空字符串直连(t *testing.T) { + trimmed, parsed, err := Parse("") + if err != nil { + t.Fatalf("空字符串应直连: %v", err) + } + if trimmed != "" { + t.Errorf("trimmed 应为空: got %q", trimmed) + } + if parsed != nil { + t.Errorf("parsed 应为 nil: got %v", parsed) + } +} + +func TestParse_空白字符串直连(t *testing.T) { + trimmed, parsed, err := Parse(" ") + if err != nil { + t.Fatalf("空白字符串应直连: %v", err) + } + if trimmed != "" { + t.Errorf("trimmed 应为空: got %q", 
trimmed) + } + if parsed != nil { + t.Errorf("parsed 应为 nil: got %v", parsed) + } +} + +func TestParse_有效HTTP代理(t *testing.T) { + trimmed, parsed, err := Parse("http://proxy.example.com:8080") + if err != nil { + t.Fatalf("有效 HTTP 代理应成功: %v", err) + } + if trimmed != "http://proxy.example.com:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } + if parsed == nil { + t.Fatal("parsed 不应为 nil") + } + if parsed.Host != "proxy.example.com:8080" { + t.Errorf("Host 不匹配: got %q", parsed.Host) + } +} + +func TestParse_有效HTTPS代理(t *testing.T) { + _, parsed, err := Parse("https://proxy.example.com:443") + if err != nil { + t.Fatalf("有效 HTTPS 代理应成功: %v", err) + } + if parsed.Scheme != "https" { + t.Errorf("Scheme 不匹配: got %q", parsed.Scheme) + } +} + +func TestParse_有效SOCKS5代理_自动升级为SOCKS5H(t *testing.T) { + trimmed, parsed, err := Parse("socks5://127.0.0.1:1080") + if err != nil { + t.Fatalf("有效 SOCKS5 代理应成功: %v", err) + } + // socks5 自动升级为 socks5h,确保 DNS 由代理端解析 + if trimmed != "socks5h://127.0.0.1:1080" { + t.Errorf("trimmed 应升级为 socks5h: got %q", trimmed) + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } +} + +func TestParse_无效URL(t *testing.T) { + _, _, err := Parse("://invalid") + if err == nil { + t.Fatal("无效 URL 应返回错误") + } + if !strings.Contains(err.Error(), "invalid proxy URL") { + t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error()) + } +} + +func TestParse_缺少Host(t *testing.T) { + _, _, err := Parse("http://") + if err == nil { + t.Fatal("缺少 host 应返回错误") + } + if !strings.Contains(err.Error(), "missing host") { + t.Errorf("错误信息应包含 'missing host': got %s", err.Error()) + } +} + +func TestParse_不支持的Scheme(t *testing.T) { + _, _, err := Parse("ftp://proxy.example.com:21") + if err == nil { + t.Fatal("不支持的 scheme 应返回错误") + } + if !strings.Contains(err.Error(), "unsupported proxy scheme") { + t.Errorf("错误信息应包含 'unsupported proxy scheme': got %s", err.Error()) + } +} + +func TestParse_含密码URL脱敏(t *testing.T) { + // 
场景 1: 带密码的 socks5 URL 应成功解析并升级为 socks5h + trimmed, parsed, err := Parse("socks5://user:secret_password@proxy.local:1080") + if err != nil { + t.Fatalf("含密码的有效 URL 应成功: %v", err) + } + if trimmed == "" || parsed == nil { + t.Fatal("应返回非空结果") + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } + if !strings.HasPrefix(trimmed, "socks5h://") { + t.Errorf("trimmed 应以 socks5h:// 开头: got %q", trimmed) + } + if parsed.User == nil { + t.Error("升级后应保留 UserInfo") + } + + // 场景 2: 带密码但缺少 host(触发 Redacted 脱敏路径) + _, _, err = Parse("http://user:secret_password@:0/") + if err == nil { + t.Fatal("缺少 host 应返回错误") + } + if strings.Contains(err.Error(), "secret_password") { + t.Error("错误信息不应包含明文密码") + } + if !strings.Contains(err.Error(), "missing host") { + t.Errorf("错误信息应包含 'missing host': got %s", err.Error()) + } +} + +func TestParse_带空白的有效URL(t *testing.T) { + trimmed, parsed, err := Parse(" http://proxy.example.com:8080 ") + if err != nil { + t.Fatalf("带空白的有效 URL 应成功: %v", err) + } + if trimmed != "http://proxy.example.com:8080" { + t.Errorf("trimmed 应去除空白: got %q", trimmed) + } + if parsed == nil { + t.Fatal("parsed 不应为 nil") + } +} + +func TestParse_Scheme大小写不敏感(t *testing.T) { + // 大写 SOCKS5 应被接受并升级为 socks5h + trimmed, parsed, err := Parse("SOCKS5://proxy.example.com:1080") + if err != nil { + t.Fatalf("大写 SOCKS5 应被接受: %v", err) + } + if parsed.Scheme != "socks5h" { + t.Errorf("大写 SOCKS5 Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } + if !strings.HasPrefix(trimmed, "socks5h://") { + t.Errorf("大写 SOCKS5 trimmed 应升级为 socks5h://: got %q", trimmed) + } + + // 大写 HTTP 应被接受(不变) + _, _, err = Parse("HTTP://proxy.example.com:8080") + if err != nil { + t.Fatalf("大写 HTTP 应被接受: %v", err) + } +} + +func TestParse_带认证的有效代理(t *testing.T) { + trimmed, parsed, err := Parse("http://user:pass@proxy.example.com:8080") + if err != nil { + t.Fatalf("带认证的代理 URL 应成功: %v", err) + } + if parsed.User == nil { + t.Error("应保留 UserInfo") + } + if trimmed 
!= "http://user:pass@proxy.example.com:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } +} + +func TestParse_IPv6地址(t *testing.T) { + trimmed, parsed, err := Parse("http://[::1]:8080") + if err != nil { + t.Fatalf("IPv6 代理 URL 应成功: %v", err) + } + if parsed.Hostname() != "::1" { + t.Errorf("Hostname 不匹配: got %q", parsed.Hostname()) + } + if trimmed != "http://[::1]:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } +} + +func TestParse_SOCKS5H保持不变(t *testing.T) { + trimmed, parsed, err := Parse("socks5h://proxy.local:1080") + if err != nil { + t.Fatalf("有效 SOCKS5H 代理应成功: %v", err) + } + // socks5h 不需要升级,应保持原样 + if trimmed != "socks5h://proxy.local:1080" { + t.Errorf("trimmed 不应变化: got %q", trimmed) + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应保持 socks5h: got %q", parsed.Scheme) + } +} + +func TestParse_无Scheme裸地址(t *testing.T) { + // 无 scheme 的裸地址,Go url.Parse 将其视为 path,Host 为空 + _, _, err := Parse("proxy.example.com:8080") + if err == nil { + t.Fatal("无 scheme 的裸地址应返回错误") + } +} diff --git a/backend/internal/pkg/proxyutil/dialer.go b/backend/internal/pkg/proxyutil/dialer.go index 91b224a2..e437cae3 100644 --- a/backend/internal/pkg/proxyutil/dialer.go +++ b/backend/internal/pkg/proxyutil/dialer.go @@ -2,7 +2,11 @@ // // 支持的代理协议: // - HTTP/HTTPS: 通过 Transport.Proxy 设置 -// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS) +// - SOCKS5: 通过 Transport.DialContext 设置(客户端本地解析 DNS) +// - SOCKS5H: 通过 Transport.DialContext 设置(代理端远程解析 DNS,推荐) +// +// 注意:proxyurl.Parse() 会自动将 socks5:// 升级为 socks5h://, +// 确保 DNS 也由代理端解析,防止 DNS 泄漏。 package proxyutil import ( @@ -20,7 +24,8 @@ import ( // // 支持的协议: // - http/https: 设置 transport.Proxy -// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS) +// - socks5: 设置 transport.DialContext(客户端本地解析 DNS) +// - socks5h: 设置 transport.DialContext(代理端远程解析 DNS,推荐) // // 参数: // - transport: 需要配置的 http.Transport diff --git a/backend/internal/pkg/usagestats/usage_log_types.go 
b/backend/internal/pkg/usagestats/usage_log_types.go index 314a6d3c..99c9cda7 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -57,25 +57,37 @@ type DashboardStats struct { // TrendDataPoint represents a single point in trend data type TrendDataPoint struct { - Date string `json:"date"` - Requests int64 `json:"requests"` - InputTokens int64 `json:"input_tokens"` - OutputTokens int64 `json:"output_tokens"` - CacheTokens int64 `json:"cache_tokens"` - TotalTokens int64 `json:"total_tokens"` - Cost float64 `json:"cost"` // 标准计费 - ActualCost float64 `json:"actual_cost"` // 实际扣除 + Date string `json:"date"` + Requests int64 `json:"requests"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheCreationTokens int64 `json:"cache_creation_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 } // ModelStat represents usage statistics for a single model type ModelStat struct { - Model string `json:"model"` - Requests int64 `json:"requests"` - InputTokens int64 `json:"input_tokens"` - OutputTokens int64 `json:"output_tokens"` - TotalTokens int64 `json:"total_tokens"` - Cost float64 `json:"cost"` // 标准计费 - ActualCost float64 `json:"actual_cost"` // 实际扣除 + Model string `json:"model"` + Requests int64 `json:"requests"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheCreationTokens int64 `json:"cache_creation_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 +} + +// EndpointStat represents usage statistics for a single request endpoint. 
+type EndpointStat struct { + Endpoint string `json:"endpoint"` + Requests int64 `json:"requests"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 } // GroupStat represents usage statistics for a single group @@ -93,12 +105,30 @@ type UserUsageTrendPoint struct { Date string `json:"date"` UserID int64 `json:"user_id"` Email string `json:"email"` + Username string `json:"username"` Requests int64 `json:"requests"` Tokens int64 `json:"tokens"` Cost float64 `json:"cost"` // 标准计费 ActualCost float64 `json:"actual_cost"` // 实际扣除 } +// UserSpendingRankingItem represents a user spending ranking row. +type UserSpendingRankingItem struct { + UserID int64 `json:"user_id"` + Email string `json:"email"` + ActualCost float64 `json:"actual_cost"` // 实际扣除 + Requests int64 `json:"requests"` + Tokens int64 `json:"tokens"` +} + +// UserSpendingRankingResponse represents ranking rows plus total spend for the time range. +type UserSpendingRankingResponse struct { + Ranking []UserSpendingRankingItem `json:"ranking"` + TotalActualCost float64 `json:"total_actual_cost"` + TotalRequests int64 `json:"total_requests"` + TotalTokens int64 `json:"total_tokens"` +} + // APIKeyUsageTrendPoint represents API key usage trend data point type APIKeyUsageTrendPoint struct { Date string `json:"date"` @@ -154,19 +184,24 @@ type UsageLogFilters struct { BillingType *int8 StartTime *time.Time EndTime *time.Time + // ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging. 
+ ExactTotal bool } // UsageStats represents usage statistics type UsageStats struct { - TotalRequests int64 `json:"total_requests"` - TotalInputTokens int64 `json:"total_input_tokens"` - TotalOutputTokens int64 `json:"total_output_tokens"` - TotalCacheTokens int64 `json:"total_cache_tokens"` - TotalTokens int64 `json:"total_tokens"` - TotalCost float64 `json:"total_cost"` - TotalActualCost float64 `json:"total_actual_cost"` - TotalAccountCost *float64 `json:"total_account_cost,omitempty"` - AverageDurationMs float64 `json:"average_duration_ms"` + TotalRequests int64 `json:"total_requests"` + TotalInputTokens int64 `json:"total_input_tokens"` + TotalOutputTokens int64 `json:"total_output_tokens"` + TotalCacheTokens int64 `json:"total_cache_tokens"` + TotalTokens int64 `json:"total_tokens"` + TotalCost float64 `json:"total_cost"` + TotalActualCost float64 `json:"total_actual_cost"` + TotalAccountCost *float64 `json:"total_account_cost,omitempty"` + AverageDurationMs float64 `json:"average_duration_ms"` + Endpoints []EndpointStat `json:"endpoints,omitempty"` + UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"` + EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"` } // BatchUserUsageStats represents usage stats for a single user @@ -233,7 +268,9 @@ type AccountUsageSummary struct { // AccountUsageStatsResponse represents the full usage statistics response for an account type AccountUsageStatsResponse struct { - History []AccountUsageHistory `json:"history"` - Summary AccountUsageSummary `json:"summary"` - Models []ModelStat `json:"models"` + History []AccountUsageHistory `json:"history"` + Summary AccountUsageSummary `json:"summary"` + Models []ModelStat `json:"models"` + Endpoints []EndpointStat `json:"endpoints"` + UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"` } diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 4aa74928..20ff7373 100644 --- 
a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -16,6 +16,7 @@ import ( "encoding/json" "errors" "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -50,6 +51,18 @@ type accountRepository struct { schedulerCache service.SchedulerCache } +var schedulerNeutralExtraKeyPrefixes = []string{ + "codex_primary_", + "codex_secondary_", + "codex_5h_", + "codex_7d_", +} + +var schedulerNeutralExtraKeys = map[string]struct{}{ + "codex_usage_updated_at": {}, + "session_window_utilization": {}, +} + // NewAccountRepository 创建账户仓储实例。 // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。 func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository { @@ -84,6 +97,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account if account.RateMultiplier != nil { builder.SetRateMultiplier(*account.RateMultiplier) } + if account.LoadFactor != nil { + builder.SetLoadFactor(*account.LoadFactor) + } if account.ProxyID != nil { builder.SetProxyID(*account.ProxyID) @@ -318,6 +334,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account if account.RateMultiplier != nil { builder.SetRateMultiplier(*account.RateMultiplier) } + if account.LoadFactor != nil { + builder.SetLoadFactor(*account.LoadFactor) + } else { + builder.ClearLoadFactor() + } if account.ProxyID != nil { builder.SetProxyID(*account.ProxyID) @@ -376,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) } - if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable { - 
r.syncSchedulerAccountSnapshot(ctx, account.ID) - } + // 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照, + // 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。 + r.syncSchedulerAccountSnapshot(ctx, account.ID) return nil } @@ -437,6 +458,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati switch status { case "rate_limited": q = q.Where(dbaccount.RateLimitResetAtGT(time.Now())) + case "temp_unschedulable": + q = q.Where(dbpredicate.Account(func(s *entsql.Selector) { + col := s.C("temp_unschedulable_until") + s.Where(entsql.And( + entsql.Not(entsql.IsNull(col)), + entsql.GT(col, entsql.Expr("NOW()")), + )) + })) default: q = q.Where(dbaccount.StatusEQ(status)) } @@ -640,7 +669,14 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error { SetStatus(service.StatusActive). SetErrorMessage(""). Save(ctx) - return err + if err != nil { + return err + } + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear error failed: account=%d err=%v", id, err) + } + r.syncSchedulerAccountSnapshot(ctx, id) + return nil } func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { @@ -829,6 +865,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat return r.accountsToService(ctx, accounts) } +func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + now := time.Now() + accounts, err := r.client.Account.Query(). 
+ Where( + dbaccount.PlatformEQ(platform), + dbaccount.StatusEQ(service.StatusActive), + dbaccount.SchedulableEQ(true), + dbaccount.Not(dbaccount.HasAccountGroups()), + tempUnschedulablePredicate(), + notExpiredPredicate(now), + dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), + dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), + ). + Order(dbent.Asc(dbaccount.FieldPriority)). + All(ctx) + if err != nil { + return nil, err + } + return r.accountsToService(ctx, accounts) +} + +func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + if len(platforms) == 0 { + return nil, nil + } + now := time.Now() + accounts, err := r.client.Account.Query(). + Where( + dbaccount.PlatformIn(platforms...), + dbaccount.StatusEQ(service.StatusActive), + dbaccount.SchedulableEQ(true), + dbaccount.Not(dbaccount.HasAccountGroups()), + tempUnschedulablePredicate(), + notExpiredPredicate(now), + dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), + dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), + ). + Order(dbent.Asc(dbaccount.FieldPriority)). 
+ All(ctx) + if err != nil { + return nil, err + } + return r.accountsToService(ctx, accounts) +} + func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { if len(platforms) == 0 { return nil, nil @@ -854,6 +935,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err) } + r.syncSchedulerAccountSnapshot(ctx, id) return nil } @@ -969,6 +1051,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err) } + r.syncSchedulerAccountSnapshot(ctx, id) return nil } @@ -1115,12 +1198,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m if affected == 0 { return service.ErrAccountNotFound } - if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) + if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) + } + } else { + // 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照, + // 让 sticky session / GetAccount 命中缓存时也能读到最新数据, + // 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。 + r.syncSchedulerAccountSnapshot(ctx, id) } return nil } +func 
shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool { + if len(updates) == 0 { + return false + } + for key := range updates { + if isSchedulerNeutralExtraKey(key) { + continue + } + return true + } + return false +} + +func isSchedulerNeutralExtraKey(key string) bool { + key = strings.TrimSpace(key) + if key == "" { + return false + } + if _, ok := schedulerNeutralExtraKeys[key]; ok { + return true + } + for _, prefix := range schedulerNeutralExtraKeyPrefixes { + if strings.HasPrefix(key, prefix) { + return true + } + } + return false +} + func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { if len(ids) == 0 { return 0, nil @@ -1160,6 +1279,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates args = append(args, *updates.RateMultiplier) idx++ } + if updates.LoadFactor != nil { + if *updates.LoadFactor <= 0 { + setClauses = append(setClauses, "load_factor = NULL") + } else { + setClauses = append(setClauses, "load_factor = $"+itoa(idx)) + args = append(args, *updates.LoadFactor) + idx++ + } + } if updates.Status != nil { setClauses = append(setClauses, "status = $"+itoa(idx)) args = append(args, *updates.Status) @@ -1482,6 +1610,7 @@ func accountEntityToService(m *dbent.Account) *service.Account { Concurrency: m.Concurrency, Priority: m.Priority, RateMultiplier: &rateMultiplier, + LoadFactor: m.LoadFactor, Status: m.Status, ErrorMessage: derefString(m.ErrorMessage), LastUsedAt: m.LastUsedAt, @@ -1594,3 +1723,186 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va return r.accountsToService(ctx, accounts) } + +// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string. +const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')` + +// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired. 
+// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes. +const dailyExpiredExpr = `( + CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed' + THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz) + ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + END +)` + +// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired. +const weeklyExpiredExpr = `( + CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed' + THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz) + ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + END +)` + +// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs. +// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour. +// This correctly handles long-inactive accounts by jumping directly to the next valid reset point. 
+const nextDailyResetAtExpr = `( + CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed' + THEN to_char(( + -- Compute today's reset point in the configured timezone, then pick next future one + CASE WHEN NOW() >= ( + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + -- NOW() is at or past today's reset point → next reset is tomorrow + THEN ( + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval + + '1 day'::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + -- NOW() is before today's reset point → next reset is today + ELSE ( + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + END + ) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"') + ELSE NULL END +)` + +// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs. +// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour. +// This correctly handles long-inactive accounts by jumping directly to the next valid reset point. 
+const nextWeeklyResetAtExpr = `( + CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed' + THEN to_char(( + -- Compute this week's reset point in the configured timezone + -- Step 1: get today's date at reset hour in configured tz + -- Step 2: compute days forward to target weekday + -- Step 3: if same day but past reset hour, advance 7 days + CASE + WHEN ( + -- days_forward = (target_day - current_day + 7) % 7 + (COALESCE((extra->>'quota_weekly_reset_day')::int, 1) + - EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int + + 7) % 7 + ) = 0 AND NOW() >= ( + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + -- Same weekday and past reset hour → next week + THEN ( + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval + + '7 days'::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + ELSE ( + -- Advance to target weekday this week (or next if days_forward > 0) + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval + + (( + (COALESCE((extra->>'quota_weekly_reset_day')::int, 1) + - EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int + + 7) % 7 + ) || ' days')::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + END + ) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"') + ELSE NULL END +)` + +// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度) +// 日/周额度在周期过期时自动重置为 0 再递增。 +// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。 +func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + rows, err := 
r.sql.QueryContext(ctx, + `UPDATE accounts SET extra = ( + COALESCE(extra, '{}'::jsonb) + -- 总额度:始终递增 + || jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1) + -- 日额度:仅在 quota_daily_limit > 0 时处理 + || CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 'quota_daily_used', + CASE WHEN `+dailyExpiredExpr+` + THEN $1 + ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END, + 'quota_daily_start', + CASE WHEN `+dailyExpiredExpr+` + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END + ) + -- 固定模式重置时更新下次重置时间 + || CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL + THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`) + ELSE '{}'::jsonb END + ELSE '{}'::jsonb END + -- 周额度:仅在 quota_weekly_limit > 0 时处理 + || CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 'quota_weekly_used', + CASE WHEN `+weeklyExpiredExpr+` + THEN $1 + ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END, + 'quota_weekly_start', + CASE WHEN `+weeklyExpiredExpr+` + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END + ) + -- 固定模式重置时更新下次重置时间 + || CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL + THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`) + ELSE '{}'::jsonb END + ELSE '{}'::jsonb END + ), updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + RETURNING + COALESCE((extra->>'quota_used')::numeric, 0), + COALESCE((extra->>'quota_limit')::numeric, 0)`, + amount, id) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + var newUsed, limit float64 + if rows.Next() { + if err := rows.Scan(&newUsed, &limit); err != nil { + return err + } + } + if err := rows.Err(); err != nil { + return err + } + + // 任一维度配额刚超限时触发调度快照刷新 + if limit > 0 && newUsed >= limit && (newUsed-amount) < limit { + if err 
:= enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err) + } + } + return nil +} + +// ResetQuotaUsed 重置账号所有维度的配额用量为 0 +// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间 +func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error { + _, err := r.sql.ExecContext(ctx, + `UPDATE accounts SET extra = ( + COALESCE(extra, '{}'::jsonb) + || '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb + ) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL`, + id) + if err != nil { + return err + } + // 重置配额后触发调度快照刷新,使账号重新参与调度 + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota reset failed: account=%d err=%v", id, err) + } + return nil +} diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index fd48a5d4..e697802e 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -23,6 +23,7 @@ type AccountRepoSuite struct { type schedulerCacheRecorder struct { setAccounts []*service.Account + accounts map[int64]*service.Account } func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) { @@ -34,11 +35,20 @@ func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service } func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) { - return nil, nil + if s.accounts == nil { + return nil, nil + } + return 
s.accounts[accountID], nil } func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error { s.setAccounts = append(s.setAccounts, account) + if s.accounts == nil { + s.accounts = make(map[int64]*service.Account) + } + if account != nil { + s.accounts[account.ID] = account + } return nil } @@ -132,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() { s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status) } +func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: "sync-credentials-update", + Status: service.StatusActive, + Schedulable: true, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5.1", + }, + }, + }) + cacheRecorder := &schedulerCacheRecorder{} + s.repo.schedulerCache = cacheRecorder + + account.Credentials = map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5.2", + }, + } + err := s.repo.Update(s.ctx, account) + s.Require().NoError(err, "Update") + + s.Require().Len(cacheRecorder.setAccounts, 1) + s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID) + mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any) + s.Require().True(ok) + s.Require().Equal("gpt-5.2", mapping["gpt-5"]) +} + func (s *AccountRepoSuite) TestDelete() { account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"}) @@ -558,6 +597,26 @@ func (s *AccountRepoSuite) TestSetError() { s.Require().Equal("something went wrong", got.ErrorMessage) } +func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: "acc-clear-err", + Status: service.StatusError, + ErrorMessage: "temporary error", + }) + cacheRecorder := &schedulerCacheRecorder{} + s.repo.schedulerCache = cacheRecorder + + 
s.Require().NoError(s.repo.ClearError(s.ctx, account.ID)) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().Equal(service.StatusActive, got.Status) + s.Require().Empty(got.ErrorMessage) + s.Require().Len(cacheRecorder.setAccounts, 1) + s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID) + s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status) +} + // --- UpdateSessionWindow --- func (s *AccountRepoSuite) TestUpdateSessionWindow() { @@ -603,6 +662,96 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { s.Require().Equal("val", got.Extra["key"]) } +func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: "acc-extra-neutral", + Platform: service.PlatformOpenAI, + Extra: map[string]any{"codex_usage_updated_at": "old"}, + }) + cacheRecorder := &schedulerCacheRecorder{ + accounts: map[int64]*service.Account{ + account.ID: { + ID: account.ID, + Platform: account.Platform, + Status: service.StatusDisabled, + Extra: map[string]any{ + "codex_usage_updated_at": "old", + }, + }, + }, + } + s.repo.schedulerCache = cacheRecorder + + updates := map[string]any{ + "codex_usage_updated_at": "2026-03-11T10:00:00Z", + "codex_5h_used_percent": 88.5, + "session_window_utilization": 0.42, + } + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates)) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().Equal("2026-03-11T10:00:00Z", got.Extra["codex_usage_updated_at"]) + s.Require().Equal(88.5, got.Extra["codex_5h_used_percent"]) + s.Require().Equal(0.42, got.Extra["session_window_utilization"]) + + var outboxCount int + s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount)) + s.Require().Zero(outboxCount) + s.Require().Len(cacheRecorder.setAccounts, 1) + 
s.Require().NotNil(cacheRecorder.accounts[account.ID]) + s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status) + s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"]) +} + +func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: "acc-extra-codex-exhausted", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Extra: map[string]any{}, + }) + cacheRecorder := &schedulerCacheRecorder{} + s.repo.schedulerCache = cacheRecorder + _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") + s.Require().NoError(err) + + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": "2026-03-12T13:00:00Z", + "codex_7d_reset_after_seconds": 86400, + })) + + var count int + err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count) + s.Require().NoError(err) + s.Require().Equal(0, count) + s.Require().Len(cacheRecorder.setAccounts, 1) + s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID) + s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status) + s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"]) +} + +func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: "acc-extra-mixed", + Platform: service.PlatformAntigravity, + Extra: map[string]any{}, + }) + _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") + s.Require().NoError(err) + + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{ + "mixed_scheduling": true, + "codex_usage_updated_at": "2026-03-11T10:00:00Z", + })) + + var count int + err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM 
scheduler_outbox", nil, &count) + s.Require().NoError(err) + s.Require().Equal(1, count) +} + // --- GetByCRSAccountID --- func (s *AccountRepoSuite) TestGetByCRSAccountID() { diff --git a/backend/internal/repository/allowed_groups_contract_integration_test.go b/backend/internal/repository/allowed_groups_contract_integration_test.go index 0d0f11e5..b0af0d54 100644 --- a/backend/internal/repository/allowed_groups_contract_integration_test.go +++ b/backend/internal/repository/allowed_groups_contract_integration_test.go @@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t userRepo := newUserRepositoryWithSQL(entClient, tx) groupRepo := newGroupRepositoryWithSQL(entClient, tx) - apiKeyRepo := NewAPIKeyRepository(entClient) + apiKeyRepo := newAPIKeyRepositoryWithSQL(entClient, tx) u := &service.User{ Email: uniqueTestValue(t, "cascade-user") + "@example.com", diff --git a/backend/internal/repository/announcement_repo.go b/backend/internal/repository/announcement_repo.go index 52029e4e..53dc335f 100644 --- a/backend/internal/repository/announcement_repo.go +++ b/backend/internal/repository/announcement_repo.go @@ -24,6 +24,7 @@ func (r *announcementRepository) Create(ctx context.Context, a *service.Announce SetTitle(a.Title). SetContent(a.Content). SetStatus(a.Status). + SetNotifyMode(a.NotifyMode). SetTargeting(a.Targeting) if a.StartsAt != nil { @@ -64,6 +65,7 @@ func (r *announcementRepository) Update(ctx context.Context, a *service.Announce SetTitle(a.Title). SetContent(a.Content). SetStatus(a.Status). + SetNotifyMode(a.NotifyMode). 
SetTargeting(a.Targeting) if a.StartsAt != nil { @@ -169,17 +171,18 @@ func announcementEntityToService(m *dbent.Announcement) *service.Announcement { return nil } return &service.Announcement{ - ID: m.ID, - Title: m.Title, - Content: m.Content, - Status: m.Status, - Targeting: m.Targeting, - StartsAt: m.StartsAt, - EndsAt: m.EndsAt, - CreatedBy: m.CreatedBy, - UpdatedBy: m.UpdatedBy, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, + ID: m.ID, + Title: m.Title, + Content: m.Content, + Status: m.Status, + NotifyMode: m.NotifyMode, + Targeting: m.Targeting, + StartsAt: m.StartsAt, + EndsAt: m.EndsAt, + CreatedBy: m.CreatedBy, + UpdatedBy: m.UpdatedBy, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, } } diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index b9ce60a5..4c7f38a8 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -2,6 +2,7 @@ package repository import ( "context" + "database/sql" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -16,10 +17,15 @@ import ( type apiKeyRepository struct { client *dbent.Client + sql sqlExecutor } -func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository { - return &apiKeyRepository{client: client} +func NewAPIKeyRepository(client *dbent.Client, sqlDB *sql.DB) service.APIKeyRepository { + return newAPIKeyRepositoryWithSQL(client, sqlDB) +} + +func newAPIKeyRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *apiKeyRepository { + return &apiKeyRepository{client: client, sql: sqlq} } func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery { @@ -37,7 +43,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro SetNillableLastUsedAt(key.LastUsedAt). SetQuota(key.Quota). SetQuotaUsed(key.QuotaUsed). - SetNillableExpiresAt(key.ExpiresAt) + SetNillableExpiresAt(key.ExpiresAt). + SetRateLimit5h(key.RateLimit5h). + SetRateLimit1d(key.RateLimit1d). 
+ SetRateLimit7d(key.RateLimit7d) if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -118,6 +127,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se apikey.FieldQuota, apikey.FieldQuotaUsed, apikey.FieldExpiresAt, + apikey.FieldRateLimit5h, + apikey.FieldRateLimit1d, + apikey.FieldRateLimit7d, ). WithUser(func(q *dbent.UserQuery) { q.Select( @@ -153,6 +165,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldModelRouting, group.FieldMcpXMLInject, group.FieldSupportedModelScopes, + group.FieldAllowMessagesDispatch, + group.FieldDefaultMappedModel, ) }). Only(ctx) @@ -179,6 +193,12 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro SetStatus(key.Status). SetQuota(key.Quota). SetQuotaUsed(key.QuotaUsed). + SetRateLimit5h(key.RateLimit5h). + SetRateLimit1d(key.RateLimit1d). + SetRateLimit7d(key.RateLimit7d). + SetUsage5h(key.Usage5h). + SetUsage1d(key.Usage1d). + SetUsage7d(key.Usage7d). 
SetUpdatedAt(now) if key.GroupID != nil { builder.SetGroupID(*key.GroupID) @@ -193,6 +213,23 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro builder.ClearExpiresAt() } + // Rate limit window start times + if key.Window5hStart != nil { + builder.SetWindow5hStart(*key.Window5hStart) + } else { + builder.ClearWindow5hStart() + } + if key.Window1dStart != nil { + builder.SetWindow1dStart(*key.Window1dStart) + } else { + builder.ClearWindow1dStart() + } + if key.Window7dStart != nil { + builder.SetWindow7dStart(*key.Window7dStart) + } else { + builder.ClearWindow7dStart() + } + // IP 限制字段 if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -246,9 +283,27 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { return nil } -func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { +func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { q := r.activeQuery().Where(apikey.UserIDEQ(userID)) + // Apply filters + if filters.Search != "" { + q = q.Where(apikey.Or( + apikey.NameContainsFold(filters.Search), + apikey.KeyContainsFold(filters.Search), + )) + } + if filters.Status != "" { + q = q.Where(apikey.StatusEQ(filters.Status)) + } + if filters.GroupID != nil { + if *filters.GroupID == 0 { + q = q.Where(apikey.GroupIDIsNil()) + } else { + q = q.Where(apikey.GroupIDEQ(*filters.GroupID)) + } + } + total, err := q.Count(ctx) if err != nil { return nil, nil, err @@ -397,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo return updated.QuotaUsed, nil } +// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key +// as quota_exhausted, and returns the latest quota state in 
one round trip. +func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) { + query := ` + UPDATE api_keys + SET + quota_used = quota_used + $1, + status = CASE + WHEN quota > 0 AND quota_used + $1 >= quota THEN $2 + ELSE status + END, + updated_at = NOW() + WHERE id = $3 AND deleted_at IS NULL + RETURNING quota_used, quota, key, status + ` + + state := &service.APIKeyQuotaUsageState{} + if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil { + if err == sql.ErrNoRows { + return nil, service.ErrAPIKeyNotFound + } + return nil, err + } + return state, nil +} + func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { affected, err := r.client.APIKey.Update(). Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). @@ -412,25 +493,92 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt return nil } +// IncrementRateLimitUsage atomically increments all rate limit usage counters and initializes +// window start times via COALESCE if not already set. 
// IncrementRateLimitUsage atomically advances the 5h/1d/7d rolling usage
// counters for one API key in a single UPDATE statement, so concurrent
// requests cannot lose increments (no read-modify-write race).
//
// Per window the CASE expressions implement "reset on expiry, else accumulate":
//   - usage_*: if the window start is set and the window duration has elapsed,
//     the counter restarts at $1 (this request's cost); otherwise $1 is added.
//   - window_*_start: (re)initialized when NULL or expired. The 5h window is
//     anchored at NOW(), while the 1d/7d windows are anchored at
//     date_trunc('day', NOW()) — presumably to align daily/weekly accounting
//     to midnight; confirm against the rate-limit service's expectations.
//
// Soft-deleted keys (deleted_at set) are excluded by the WHERE clause. A
// zero-row update is not treated as an error; only database failures surface.
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
	_, err := r.sql.ExecContext(ctx, `
		UPDATE api_keys SET
			usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
			usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
			usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
			window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
			window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
			window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
			updated_at = NOW()
		WHERE id = $2 AND deleted_at IS NULL`,
		cost, id)
	return err
}

// ResetRateLimitWindows resets expired rate limit windows atomically.
// ResetRateLimitWindows clears any of the 5h/1d/7d windows whose duration has
// elapsed, in one atomic UPDATE. For each expired window the usage counter is
// zeroed and the window start is refreshed (NOW() for the 5h window,
// date_trunc('day', NOW()) for the 1d/7d windows, mirroring the anchoring used
// by IncrementRateLimitUsage). Windows that are still running, or that were
// never started (start column IS NULL), are left untouched.
//
// Soft-deleted keys (deleted_at set) are excluded; a zero-row update is not
// an error — only database failures are returned.
func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64) error {
	_, err := r.sql.ExecContext(ctx, `
		UPDATE api_keys SET
			usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
			window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
			usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
			window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
			usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
			window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
			updated_at = NOW()
		WHERE id = $1 AND deleted_at IS NULL`,
		id)
	return err
}

// GetRateLimitData returns the current rate limit usage and window start times for an API key.
+func (r *apiKeyRepository) GetRateLimitData(ctx context.Context, id int64) (result *service.APIKeyRateLimitData, err error) { + rows, err := r.sql.QueryContext(ctx, ` + SELECT usage_5h, usage_1d, usage_7d, window_5h_start, window_1d_start, window_7d_start + FROM api_keys + WHERE id = $1 AND deleted_at IS NULL`, + id) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() + if !rows.Next() { + return nil, service.ErrAPIKeyNotFound + } + data := &service.APIKeyRateLimitData{} + if err := rows.Scan(&data.Usage5h, &data.Usage1d, &data.Usage7d, &data.Window5hStart, &data.Window1dStart, &data.Window7dStart); err != nil { + return nil, err + } + return data, rows.Err() +} + func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { if m == nil { return nil } out := &service.APIKey{ - ID: m.ID, - UserID: m.UserID, - Key: m.Key, - Name: m.Name, - Status: m.Status, - IPWhitelist: m.IPWhitelist, - IPBlacklist: m.IPBlacklist, - LastUsedAt: m.LastUsedAt, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - GroupID: m.GroupID, - Quota: m.Quota, - QuotaUsed: m.QuotaUsed, - ExpiresAt: m.ExpiresAt, + ID: m.ID, + UserID: m.UserID, + Key: m.Key, + Name: m.Name, + Status: m.Status, + IPWhitelist: m.IPWhitelist, + IPBlacklist: m.IPBlacklist, + LastUsedAt: m.LastUsedAt, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + GroupID: m.GroupID, + Quota: m.Quota, + QuotaUsed: m.QuotaUsed, + ExpiresAt: m.ExpiresAt, + RateLimit5h: m.RateLimit5h, + RateLimit1d: m.RateLimit1d, + RateLimit7d: m.RateLimit7d, + Usage5h: m.Usage5h, + Usage1d: m.Usage1d, + Usage7d: m.Usage7d, + Window5hStart: m.Window5hStart, + Window1dStart: m.Window1dStart, + Window7dStart: m.Window7dStart, } if m.Edges.User != nil { out.User = userEntityToService(m.Edges.User) @@ -499,6 +647,8 @@ func groupEntityToService(g *dbent.Group) *service.Group { MCPXMLInject: g.McpXMLInject, SupportedModelScopes: g.SupportedModelScopes, 
SortOrder: g.SortOrder, + AllowMessagesDispatch: g.AllowMessagesDispatch, + DefaultMappedModel: g.DefaultMappedModel, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 303d7126..a8989ff2 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -26,7 +26,7 @@ func (s *APIKeyRepoSuite) SetupTest() { s.ctx = context.Background() tx := testEntTx(s.T()) s.client = tx.Client() - s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository) + s.repo = newAPIKeyRepositoryWithSQL(s.client, tx) } func TestAPIKeyRepoSuite(t *testing.T) { @@ -158,7 +158,7 @@ func (s *APIKeyRepoSuite) TestListByUserID() { s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil) s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil) - keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) + keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{}) s.Require().NoError(err, "ListByUserID") s.Require().Len(keys, 2) s.Require().Equal(int64(2), page.Total) @@ -170,7 +170,7 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() { s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil) } - keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}) + keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}, service.APIKeyListFilters{}) s.Require().NoError(err) s.Require().Len(keys, 2) s.Require().Equal(int64(5), page.Total) @@ -314,7 +314,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() { s.Require().Equal(service.StatusDisabled, got2.Status) s.Require().Nil(got2.GroupID) - keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, 
pagination.PaginationParams{Page: 1, PageSize: 10}) + keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{}) s.Require().NoError(err, "ListByUserID") s.Require().Equal(int64(1), page.Total) s.Require().Len(keys, 1) @@ -417,11 +417,32 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() { s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound") } +func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() { + user := s.mustCreateUser("quota-state@test.com") + key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil) + key.Quota = 3 + key.QuotaUsed = 1 + s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota") + + state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5) + s.Require().NoError(err, "IncrementQuotaUsedAndGetState") + s.Require().NotNil(state) + s.Require().Equal(3.5, state.QuotaUsed) + s.Require().Equal(3.0, state.Quota) + s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status) + s.Require().Equal(key.Key, state.Key) + + got, err := s.repo.GetByID(s.ctx, key.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal(3.5, got.QuotaUsed) + s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status) +} + // TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。 // 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。 func TestIncrementQuotaUsed_Concurrent(t *testing.T) { client := testEntClient(t) - repo := NewAPIKeyRepository(client).(*apiKeyRepository) + repo := NewAPIKeyRepository(client, integrationDB).(*apiKeyRepository) ctx := context.Background() // 创建测试用户和 API Key diff --git a/backend/internal/repository/backup_pg_dumper.go b/backend/internal/repository/backup_pg_dumper.go new file mode 100644 index 00000000..e9a92ef2 --- /dev/null +++ b/backend/internal/repository/backup_pg_dumper.go @@ -0,0 +1,98 @@ +package repository + +import ( + "context" + "fmt" + "io" + "os/exec" + + 
"github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// PgDumper implements service.DBDumper using pg_dump/psql +type PgDumper struct { + cfg *config.DatabaseConfig +} + +// NewPgDumper creates a new PgDumper +func NewPgDumper(cfg *config.Config) service.DBDumper { + return &PgDumper{cfg: &cfg.Database} +} + +// Dump executes pg_dump and returns a streaming reader of the output +func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) { + args := []string{ + "-h", d.cfg.Host, + "-p", fmt.Sprintf("%d", d.cfg.Port), + "-U", d.cfg.User, + "-d", d.cfg.DBName, + "--no-owner", + "--no-acl", + "--clean", + "--if-exists", + } + + cmd := exec.CommandContext(ctx, "pg_dump", args...) + if d.cfg.Password != "" { + cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password) + } + if d.cfg.SSLMode != "" { + cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("create stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("start pg_dump: %w", err) + } + + // 返回一个 ReadCloser:读 stdout,关闭时等待进程退出 + return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil +} + +// Restore executes psql to restore from a streaming reader +func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error { + args := []string{ + "-h", d.cfg.Host, + "-p", fmt.Sprintf("%d", d.cfg.Port), + "-U", d.cfg.User, + "-d", d.cfg.DBName, + "--single-transaction", + } + + cmd := exec.CommandContext(ctx, "psql", args...) 
+ if d.cfg.Password != "" { + cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password) + } + if d.cfg.SSLMode != "" { + cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode) + } + + cmd.Stdin = data + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("%v: %s", err, string(output)) + } + return nil +} + +// cmdReadCloser wraps a command stdout pipe and waits for the process on Close +type cmdReadCloser struct { + io.ReadCloser + cmd *exec.Cmd +} + +func (c *cmdReadCloser) Close() error { + // Close the pipe first + _ = c.ReadCloser.Close() + // Wait for the process to exit + if err := c.cmd.Wait(); err != nil { + return fmt.Errorf("pg_dump exited with error: %w", err) + } + return nil +} diff --git a/backend/internal/repository/backup_s3_store.go b/backend/internal/repository/backup_s3_store.go new file mode 100644 index 00000000..ba5434f5 --- /dev/null +++ b/backend/internal/repository/backup_s3_store.go @@ -0,0 +1,116 @@ +package repository + +import ( + "bytes" + "context" + "fmt" + "io" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage +type S3BackupStore struct { + client *s3.Client + bucket string +} + +// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores +func NewS3BackupStoreFactory() service.BackupObjectStoreFactory { + return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) { + region := cfg.Region + if region == "" { + region = "auto" // Cloudflare R2 默认 region + } + + awsCfg, err := awsconfig.LoadDefaultConfig(ctx, + awsconfig.WithRegion(region), + awsconfig.WithCredentialsProvider( + 
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""), + ), + ) + if err != nil { + return nil, fmt.Errorf("load aws config: %w", err) + } + + client := s3.NewFromConfig(awsCfg, func(o *s3.Options) { + if cfg.Endpoint != "" { + o.BaseEndpoint = &cfg.Endpoint + } + if cfg.ForcePathStyle { + o.UsePathStyle = true + } + o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware) + o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired + }) + + return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil + } +} + +func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) { + // 读取全部内容以获取大小(S3 PutObject 需要知道内容长度) + data, err := io.ReadAll(body) + if err != nil { + return 0, fmt.Errorf("read body: %w", err) + } + + _, err = s.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: &s.bucket, + Key: &key, + Body: bytes.NewReader(data), + ContentType: &contentType, + }) + if err != nil { + return 0, fmt.Errorf("S3 PutObject: %w", err) + } + return int64(len(data)), nil +} + +func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) { + result, err := s.client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: &s.bucket, + Key: &key, + }) + if err != nil { + return nil, fmt.Errorf("S3 GetObject: %w", err) + } + return result.Body, nil +} + +func (s *S3BackupStore) Delete(ctx context.Context, key string) error { + _, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: &s.bucket, + Key: &key, + }) + return err +} + +func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) { + presignClient := s3.NewPresignClient(s.client) + result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{ + Bucket: &s.bucket, + Key: &key, + }, s3.WithPresignExpires(expiry)) + if err != nil { + return "", fmt.Errorf("presign url: %w", err) + } + return result.URL, nil 
+} + +func (s *S3BackupStore) HeadBucket(ctx context.Context) error { + _, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{ + Bucket: &s.bucket, + }) + if err != nil { + return fmt.Errorf("S3 HeadBucket failed: %w", err) + } + return nil +} diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index e753e1b8..4fbdae14 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -14,10 +14,12 @@ import ( ) const ( - billingBalanceKeyPrefix = "billing:balance:" - billingSubKeyPrefix = "billing:sub:" - billingCacheTTL = 5 * time.Minute - billingCacheJitter = 30 * time.Second + billingBalanceKeyPrefix = "billing:balance:" + billingSubKeyPrefix = "billing:sub:" + billingRateLimitKeyPrefix = "apikey:rate:" + billingCacheTTL = 5 * time.Minute + billingCacheJitter = 30 * time.Second + rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window ) // jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩 @@ -49,6 +51,20 @@ const ( subFieldVersion = "version" ) +// billingRateLimitKey generates the Redis key for API key rate limit cache. +func billingRateLimitKey(keyID int64) string { + return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID) +} + +const ( + rateLimitFieldUsage5h = "usage_5h" + rateLimitFieldUsage1d = "usage_1d" + rateLimitFieldUsage7d = "usage_7d" + rateLimitFieldWindow5h = "window_5h" + rateLimitFieldWindow1d = "window_1d" + rateLimitFieldWindow7d = "window_7d" +) + var ( deductBalanceScript = redis.NewScript(` local current = redis.call('GET', KEYS[1]) @@ -73,6 +89,21 @@ var ( redis.call('EXPIRE', KEYS[1], ARGV[2]) return 1 `) + + // updateRateLimitUsageScript atomically increments all three rate limit usage counters. + // Returns 0 if the key doesn't exist (cache miss), 1 on success. 
+ updateRateLimitUsageScript = redis.NewScript(` + local exists = redis.call('EXISTS', KEYS[1]) + if exists == 0 then + return 0 + end + local cost = tonumber(ARGV[1]) + redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost) + redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost) + redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost) + redis.call('EXPIRE', KEYS[1], ARGV[2]) + return 1 + `) ) type billingCache struct { @@ -195,3 +226,69 @@ func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, key := billingSubKey(userID, groupID) return c.rdb.Del(ctx, key).Err() } + +func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) { + key := billingRateLimitKey(keyID) + result, err := c.rdb.HGetAll(ctx, key).Result() + if err != nil { + return nil, err + } + if len(result) == 0 { + return nil, redis.Nil + } + data := &service.APIKeyRateLimitCacheData{} + if v, ok := result[rateLimitFieldUsage5h]; ok { + data.Usage5h, _ = strconv.ParseFloat(v, 64) + } + if v, ok := result[rateLimitFieldUsage1d]; ok { + data.Usage1d, _ = strconv.ParseFloat(v, 64) + } + if v, ok := result[rateLimitFieldUsage7d]; ok { + data.Usage7d, _ = strconv.ParseFloat(v, 64) + } + if v, ok := result[rateLimitFieldWindow5h]; ok { + data.Window5h, _ = strconv.ParseInt(v, 10, 64) + } + if v, ok := result[rateLimitFieldWindow1d]; ok { + data.Window1d, _ = strconv.ParseInt(v, 10, 64) + } + if v, ok := result[rateLimitFieldWindow7d]; ok { + data.Window7d, _ = strconv.ParseInt(v, 10, 64) + } + return data, nil +} + +func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error { + if data == nil { + return nil + } + key := billingRateLimitKey(keyID) + fields := map[string]any{ + rateLimitFieldUsage5h: data.Usage5h, + rateLimitFieldUsage1d: data.Usage1d, + rateLimitFieldUsage7d: data.Usage7d, + rateLimitFieldWindow5h: data.Window5h, + rateLimitFieldWindow1d: 
data.Window1d, + rateLimitFieldWindow7d: data.Window7d, + } + pipe := c.rdb.Pipeline() + pipe.HSet(ctx, key, fields) + pipe.Expire(ctx, key, rateLimitCacheTTL) + _, err := pipe.Exec(ctx) + return err +} + +func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { + key := billingRateLimitKey(keyID) + _, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result() + if err != nil && !errors.Is(err, redis.Nil) { + log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err) + return err + } + return nil +} + +func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + key := billingRateLimitKey(keyID) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index 77764881..b754bd55 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -11,6 +11,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/util/logredact" @@ -28,11 +29,14 @@ func NewClaudeOAuthClient() service.ClaudeOAuthClient { type claudeOAuthService struct { baseURL string tokenURL string - clientFactory func(proxyURL string) *req.Client + clientFactory func(proxyURL string) (*req.Client, error) } func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return "", fmt.Errorf("create HTTP client: %w", err) + } var orgs []struct { UUID string `json:"uuid"` @@ -88,7 +92,10 @@ func (s *claudeOAuthService) 
GetOrganizationUUID(ctx context.Context, sessionKey } func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return "", fmt.Errorf("create HTTP client: %w", err) + } authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID) @@ -165,7 +172,10 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe } func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } // Parse code which may contain state in format "authCode#state" authCode := code @@ -223,7 +233,10 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod } func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } reqBody := map[string]any{ "grant_type": "refresh_token", @@ -253,16 +266,20 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro return &tokenResp, nil } -func createReqClient(proxyURL string) *req.Client { +func createReqClient(proxyURL string) (*req.Client, error) { // 禁用 CookieJar,确保每次授权都是干净的会话 client := req.C(). SetTimeout(60 * time.Second). ImpersonateChrome(). 
SetCookieJar(nil) // 禁用 CookieJar - if strings.TrimSpace(proxyURL) != "" { - client.SetProxyURL(strings.TrimSpace(proxyURL)) + trimmed, _, err := proxyurl.Parse(proxyURL) + if err != nil { + return nil, err + } + if trimmed != "" { + client.SetProxyURL(trimmed) } - return client + return client, nil } diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go index 7395c6d8..c6383033 100644 --- a/backend/internal/repository/claude_oauth_service_test.go +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -91,7 +91,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "") @@ -169,7 +169,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "") @@ -276,7 +276,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", 
tt.isSetupToken) @@ -372,7 +372,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } resp, err := s.client.RefreshToken(context.Background(), "rt", "") diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 1198f472..1264f6bb 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -8,6 +8,7 @@ import ( "net/http" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -83,7 +84,7 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se AllowPrivateHosts: s.allowPrivateHosts, }) if err != nil { - client = &http.Client{Timeout: 30 * time.Second} + return nil, fmt.Errorf("create http client failed: %w", err) } resp, err = client.Do(req) @@ -95,7 +96,8 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + msg := fmt.Sprintf("API returned status %d: %s", resp.StatusCode, string(body)) + return nil, infraerrors.New(http.StatusInternalServerError, "UPSTREAM_ERROR", msg) } var usageResp service.ClaudeUsageResponse diff --git a/backend/internal/repository/claude_usage_service_test.go b/backend/internal/repository/claude_usage_service_test.go index 2e10f3e5..cbd0b6d3 100644 --- a/backend/internal/repository/claude_usage_service_test.go +++ b/backend/internal/repository/claude_usage_service_test.go @@ -50,7 
+50,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { allowPrivateHosts: true, } - resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") + resp, err := s.fetcher.FetchUsage(context.Background(), "at", "") require.NoError(s.T(), err, "FetchUsage") require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch") require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch") @@ -112,6 +112,17 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { require.Error(s.T(), err, "expected error for cancelled context") } +func (s *ClaudeUsageServiceSuite) TestFetchUsage_InvalidProxyReturnsError() { + s.fetcher = &claudeUsageService{ + usageURL: "http://example.com", + allowPrivateHosts: true, + } + + _, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "create http client failed") +} + func TestClaudeUsageServiceSuite(t *testing.T) { suite.Run(t, new(ClaudeUsageServiceSuite)) } diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index a2552715..8732b2ce 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -147,17 +147,47 @@ var ( return 1 `) - // cleanupExpiredSlotsScript - remove expired slots - // KEYS[1] = concurrency:account:{accountID} - // ARGV[1] = TTL (seconds) + // cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位 + // KEYS[1] = 有序集合键 + // ARGV[1] = TTL(秒) cleanupExpiredSlotsScript = redis.NewScript(` - local key = KEYS[1] - local ttl = tonumber(ARGV[1]) - local timeResult = redis.call('TIME') - local now = tonumber(timeResult[1]) - local expireBefore = now - ttl - return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) - `) + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + local timeResult = redis.call('TIME') + local now = 
tonumber(timeResult[1]) + local expireBefore = now - ttl + redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) + if redis.call('ZCARD', key) == 0 then + redis.call('DEL', key) + else + redis.call('EXPIRE', key, ttl) + end + return 1 + `) + + // startupCleanupScript 清理非当前进程前缀的槽位成员。 + // KEYS 是有序集合键列表,ARGV[1] 是当前进程前缀,ARGV[2] 是槽位 TTL。 + // 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key,否则刷新 EXPIRE。 + startupCleanupScript = redis.NewScript(` + local activePrefix = ARGV[1] + local slotTTL = tonumber(ARGV[2]) + local removed = 0 + for i = 1, #KEYS do + local key = KEYS[i] + local members = redis.call('ZRANGE', key, 0, -1) + for _, member in ipairs(members) do + if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then + removed = removed + redis.call('ZREM', key, member) + end + end + if redis.call('ZCARD', key) == 0 then + redis.call('DEL', key) + else + redis.call('EXPIRE', key, slotTTL) + end + end + return removed + `) ) type concurrencyCache struct { @@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() return err } + +func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + if activeRequestPrefix == "" { + return nil + } + + // 1. 清理有序集合中非当前进程前缀的成员 + slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"} + for _, pattern := range slotPatterns { + if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil { + return err + } + } + + // 2. 
删除所有等待队列计数器(重启后计数器失效) + waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"} + for _, pattern := range waitPatterns { + if err := c.deleteKeysByPattern(ctx, pattern); err != nil { + return err + } + } + + return nil +} + +// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。 +func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error { + const scanCount = 200 + var cursor uint64 + for { + keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result() + if err != nil { + return fmt.Errorf("scan %s: %w", pattern, err) + } + if len(keys) > 0 { + _, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result() + if err != nil { + return fmt.Errorf("cleanup slots %s: %w", pattern, err) + } + } + cursor = nextCursor + if cursor == 0 { + break + } + } + return nil +} + +// deleteKeysByPattern 扫描匹配 pattern 的键并删除。 +func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error { + const scanCount = 200 + var cursor uint64 + for { + keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result() + if err != nil { + return fmt.Errorf("scan %s: %w", pattern, err) + } + if len(keys) > 0 { + if err := c.rdb.Del(ctx, keys...).Err(); err != nil { + return fmt.Errorf("del %s: %w", pattern, err) + } + } + cursor = nextCursor + if cursor == 0 { + break + } + } + return nil +} diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go index 5983c832..5da94fc2 100644 --- a/backend/internal/repository/concurrency_cache_integration_test.go +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct { cache service.ConcurrencyCache } +func TestConcurrencyCacheSuite(t *testing.T) { + suite.Run(t, new(ConcurrencyCacheSuite)) +} + func (s *ConcurrencyCacheSuite) SetupTest() 
{ s.IntegrationRedisSuite.SetupTest() s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds())) @@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() { require.Equal(s.T(), 1, val, "expected account wait count 1") } -func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() { - accountID := int64(301) - waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) +func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() { + accountID := int64(901) + userID := int64(902) + accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) + userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID) + userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) + accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) - require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key") + now := time.Now().Unix() + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey, + redis.Z{Score: float64(now), Member: "oldproc-1"}, + redis.Z{Score: float64(now), Member: "keep-1"}, + ).Err()) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey, + redis.Z{Score: float64(now), Member: "oldproc-2"}, + redis.Z{Score: float64(now), Member: "keep-2"}, + ).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err()) - val, err := s.rdb.Get(s.ctx, waitKey).Int() - if !errors.Is(err, redis.Nil) { - require.NoError(s.T(), err, "Get waitKey") - } - require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty") + require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-")) + + accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"keep-1"}, accountMembers) + + userMembers, err := 
s.rdb.ZRange(s.ctx, userKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"keep-2"}, userMembers) + + _, err = s.rdb.Get(s.ctx, userWaitKey).Result() + require.True(s.T(), errors.Is(err, redis.Nil)) + + _, err = s.rdb.Get(s.ctx, accountWaitKey).Result() + require.True(s.T(), errors.Is(err, redis.Nil)) } func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { @@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() { require.Equal(s.T(), 2, cur) } -func TestConcurrencyCacheSuite(t *testing.T) { - suite.Run(t, new(ConcurrencyCacheSuite)) +func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() { + accountID := int64(901) + userID := int64(902) + accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) + userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID) + userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) + accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) + + now := float64(time.Now().Unix()) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, + redis.Z{Score: now, Member: "oldproc-1"}, + redis.Z{Score: now, Member: "activeproc-1"}, + ).Err()) + require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err()) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey, + redis.Z{Score: now, Member: "oldproc-2"}, + redis.Z{Score: now, Member: "activeproc-2"}, + ).Err()) + require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err()) + + require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-")) + + accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"activeproc-1"}, accountMembers) + + 
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"activeproc-2"}, userMembers) + + _, err = s.rdb.Get(s.ctx, userWaitKey).Result() + require.ErrorIs(s.T(), err, redis.Nil) + _, err = s.rdb.Get(s.ctx, accountWaitKey).Result() + require.ErrorIs(s.T(), err, redis.Nil) +} + +func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() { + accountID := int64(903) + accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err()) + require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err()) + + require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-")) + + exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result() + require.NoError(s.T(), err) + require.EqualValues(s.T(), 0, exists) } diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go index 59bbd6a3..e82a73a3 100644 --- a/backend/internal/repository/dashboard_aggregation_repo.go +++ b/backend/internal/repository/dashboard_aggregation_repo.go @@ -17,6 +17,9 @@ type dashboardAggregationRepository struct { sql sqlExecutor } +const usageLogsCleanupBatchSize = 10000 +const usageBillingDedupCleanupBatchSize = 10000 + // NewDashboardAggregationRepository 创建仪表盘预聚合仓储。 func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository { if sqlDB == nil { @@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool { } func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error { + if r == nil || r.sql == nil { + return nil + } loc := timezone.Location() startLocal := start.In(loc) endLocal := end.In(loc) @@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta dayEnd 
= dayEnd.Add(24 * time.Hour) } + if db, ok := r.sql.(*sql.DB); ok { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + txRepo := newDashboardAggregationRepositoryWithSQL(tx) + if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() + } + return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd) +} + +func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error { // 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。 if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil { return err @@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c if isPartitioned { return r.dropUsageLogsPartitions(ctx, cutoff) } - _, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC()) - return err + for { + res, err := r.sql.ExecContext(ctx, ` + WITH victims AS ( + SELECT ctid + FROM usage_logs + WHERE created_at < $1 + LIMIT $2 + ) + DELETE FROM usage_logs + WHERE ctid IN (SELECT ctid FROM victims) + `, cutoff.UTC(), usageLogsCleanupBatchSize) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected < usageLogsCleanupBatchSize { + return nil + } + } +} + +func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + for { + res, err := r.sql.ExecContext(ctx, ` + WITH victims AS ( + SELECT ctid, request_id, api_key_id, request_fingerprint, created_at + FROM usage_billing_dedup + WHERE created_at < $1 + LIMIT $2 + ), archived AS ( + INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at) + SELECT request_id, api_key_id, request_fingerprint, created_at + FROM victims + ON CONFLICT (request_id, api_key_id) DO NOTHING + ) + DELETE FROM usage_billing_dedup + WHERE 
ctid IN (SELECT ctid FROM victims) + `, cutoff.UTC(), usageBillingDedupCleanupBatchSize) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected < usageBillingDedupCleanupBatchSize { + return nil + } + } } func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { diff --git a/backend/internal/repository/ent.go b/backend/internal/repository/ent.go index 5f3f5a84..64d32192 100644 --- a/backend/internal/repository/ent.go +++ b/backend/internal/repository/ent.go @@ -89,6 +89,10 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { _ = client.Close() return nil, nil, err } + if err := ensureSimpleModeAdminConcurrency(seedCtx, client); err != nil { + _ = client.Close() + return nil, nil, err + } } return client, drv.DB(), nil diff --git a/backend/internal/repository/fixtures_integration_test.go b/backend/internal/repository/fixtures_integration_test.go index 23adb4e4..80b9cab6 100644 --- a/backend/internal/repository/fixtures_integration_test.go +++ b/backend/internal/repository/fixtures_integration_test.go @@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se SetKey(k.Key). SetName(k.Name). 
SetStatus(k.Status) + if k.Quota != 0 { + create.SetQuota(k.Quota) + } + if k.QuotaUsed != 0 { + create.SetQuotaUsed(k.QuotaUsed) + } + if k.RateLimit5h != 0 { + create.SetRateLimit5h(k.RateLimit5h) + } + if k.RateLimit1d != 0 { + create.SetRateLimit1d(k.RateLimit1d) + } + if k.RateLimit7d != 0 { + create.SetRateLimit7d(k.RateLimit7d) + } + if k.Usage5h != 0 { + create.SetUsage5h(k.Usage5h) + } + if k.Usage1d != 0 { + create.SetUsage1d(k.Usage1d) + } + if k.Usage7d != 0 { + create.SetUsage7d(k.Usage7d) + } + if k.Window5hStart != nil { + create.SetWindow5hStart(*k.Window5hStart) + } + if k.Window1dStart != nil { + create.SetWindow1dStart(*k.Window1dStart) + } + if k.Window7dStart != nil { + create.SetWindow7dStart(*k.Window7dStart) + } + if k.ExpiresAt != nil { + create.SetExpiresAt(*k.ExpiresAt) + } if k.GroupID != nil { create.SetGroupID(*k.GroupID) } diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go index 8b7fe625..eb14f313 100644 --- a/backend/internal/repository/gemini_oauth_client.go +++ b/backend/internal/repository/gemini_oauth_client.go @@ -26,7 +26,10 @@ func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient { } func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) { - client := createGeminiReqClient(proxyURL) + client, err := createGeminiReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } // Use different OAuth clients based on oauthType: // - code_assist: always use built-in Gemini CLI OAuth client (public) @@ -72,7 +75,10 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c } func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { - client := createGeminiReqClient(proxyURL) + client, err := 
createGeminiReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } oauthCfgInput := geminicli.OAuthConfig{ ClientID: c.cfg.Gemini.OAuth.ClientID, @@ -111,7 +117,7 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh return &tokenResp, nil } -func createGeminiReqClient(proxyURL string) *req.Client { +func createGeminiReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 60 * time.Second, diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index 4f63280d..b5bc6497 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -26,7 +26,11 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo } var out geminicli.LoadCodeAssistResponse - resp, err := createGeminiCliReqClient(proxyURL).R(). + client, err := createGeminiCliReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } + resp, err := client.R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). @@ -66,7 +70,11 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody) var out geminicli.OnboardUserResponse - resp, err := createGeminiCliReqClient(proxyURL).R(). + client, err := createGeminiCliReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } + resp, err := client.R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). 
@@ -98,7 +106,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken return &out, nil } -func createGeminiCliReqClient(proxyURL string) *req.Client { +func createGeminiCliReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 30 * time.Second, diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index a7d2863d..74156164 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -5,8 +5,10 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net/http" "os" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" @@ -24,13 +26,19 @@ type githubReleaseClientError struct { // NewGitHubReleaseClient 创建 GitHub Release 客户端 // proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议 +// 代理配置失败时行为由 allowDirectOnProxyError 控制: +// - false(默认):返回错误占位客户端,禁止回退到直连 +// - true:回退到直连(仅限管理员显式开启) func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient { + // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据, + // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { - if proxyURL != "" && !allowDirectOnProxyError { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy client init failed, all requests will fail", "service", "github_release", "error", err) return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} } sharedClient = &http.Client{Timeout: 30 * time.Second} @@ -42,7 +50,8 @@ func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) servi ProxyURL: proxyURL, }) if err != nil { - 
if proxyURL != "" && !allowDirectOnProxyError { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy download client init failed, all requests will fail", "service", "github_release", "error", err) return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} } downloadClient = &http.Client{Timeout: 10 * time.Minute} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 4edc8534..c195f1f1 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -59,7 +59,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetMcpXMLInject(groupIn.MCPXMLInject). - SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) + SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). + SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). + SetDefaultMappedModel(groupIn.DefaultMappedModel) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -125,7 +127,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetMcpXMLInject(groupIn.MCPXMLInject). - SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) + SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). + SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). 
+ SetDefaultMappedModel(groupIn.DefaultMappedModel) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 if groupIn.DailyLimitUSD != nil { diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index b0f15f19..a4674c1a 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/Wei-Shaw/sub2api/internal/service" @@ -235,7 +236,10 @@ func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID in // TLS 指纹客户端使用独立的缓存键,与普通客户端隔离 func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { isolation := s.getIsolationMode() - proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) + if err != nil { + return nil, err + } // TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀 cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID) poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls" @@ -373,9 +377,8 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac // - proxy: 按代理地址隔离,同一代理共享客户端 // - account: 按账户隔离,同一账户共享客户端(代理变更时重建) // - account_proxy: 按账户+代理组合隔离,最细粒度 -func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry { - entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) - return entry +func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, 
false, false) } // getClientEntry 获取或创建客户端条目 @@ -385,7 +388,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a // 获取隔离模式 isolation := s.getIsolationMode() // 标准化代理 URL 并解析 - proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) + if err != nil { + return nil, err + } // 构建缓存键(根据隔离策略不同) cacheKey := buildCacheKey(isolation, proxyKey, accountID) // 构建连接池配置键(用于检测配置变更) @@ -680,17 +686,18 @@ func buildCacheKey(isolation, proxyKey string, accountID int64) string { // - raw: 原始代理 URL 字符串 // // 返回: -// - string: 标准化的代理键(空或解析失败返回 "direct") -// - *url.URL: 解析后的 URL(空或解析失败返回 nil) -func normalizeProxyURL(raw string) (string, *url.URL) { - proxyURL := strings.TrimSpace(raw) - if proxyURL == "" { - return directProxyKey, nil - } - parsed, err := url.Parse(proxyURL) +// - string: 标准化的代理键(空返回 "direct") +// - *url.URL: 解析后的 URL(空返回 nil) +// - error: 非空代理 URL 解析失败时返回错误(禁止回退到直连) +func normalizeProxyURL(raw string) (string, *url.URL, error) { + _, parsed, err := proxyurl.Parse(raw) if err != nil { - return directProxyKey, nil + return "", nil, err } + if parsed == nil { + return directProxyKey, nil, nil + } + // 规范化:小写 scheme/host,去除路径和查询参数 parsed.Scheme = strings.ToLower(parsed.Scheme) parsed.Host = strings.ToLower(parsed.Host) parsed.Path = "" @@ -710,7 +717,7 @@ func normalizeProxyURL(raw string) (string, *url.URL) { parsed.Host = hostname } } - return parsed.String(), parsed + return parsed.String(), parsed, nil } // defaultPoolSettings 获取默认连接池配置 diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go index 1e7430a3..89892b3b 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -59,7 +59,10 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { // 模拟优化后的行为,从缓存获取客户端 b.Run("复用", func(b *testing.B) { // 预热:确保客户端已缓存 - entry 
:= svc.getOrCreateClient(proxyURL, 1, 1) + entry, err := svc.getOrCreateClient(proxyURL, 1, 1) + if err != nil { + b.Fatalf("getOrCreateClient: %v", err) + } client := entry.client b.ResetTimer() // 重置计时器,排除预热时间 for i := 0; i < b.N; i++ { diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index fbe44c5e..b3268463 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -44,7 +44,7 @@ func (s *HTTPUpstreamSuite) newService() *httpUpstreamService { // 验证未配置时使用 300 秒默认值 func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { svc := s.newService() - entry := svc.getOrCreateClient("", 0, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") @@ -55,25 +55,27 @@ func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7} svc := s.newService() - entry := svc.getOrCreateClient("", 0, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") } -// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退 -// 验证解析失败时回退到直连模式 -func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() { +// TestGetOrCreateClient_InvalidURLReturnsError 测试无效代理 URL 返回错误 +// 验证解析失败时拒绝回退到直连模式 +func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLReturnsError() { svc := s.newService() - entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1) - require.Equal(s.T(), 
directProxyKey, entry.proxyKey, "expected direct proxy fallback") + _, err := svc.getClientEntry("://bad-proxy-url", 1, 1, false, false) + require.Error(s.T(), err, "expected error for invalid proxy URL") } // TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化 // 验证等价地址能够映射到同一缓存键 func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() { - key1, _ := normalizeProxyURL("http://proxy.local:8080") - key2, _ := normalizeProxyURL("http://proxy.local:8080/") + key1, _, err1 := normalizeProxyURL("http://proxy.local:8080") + require.NoError(s.T(), err1) + key2, _, err2 := normalizeProxyURL("http://proxy.local:8080/") + require.NoError(s.T(), err2) require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match") } @@ -171,8 +173,8 @@ func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一代理,不同账户 - entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 2, 3) require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池") require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端") } @@ -183,8 +185,8 @@ func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy} svc := s.newService() // 同一账户,不同代理 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理") require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端") } @@ 
-195,8 +197,8 @@ func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一账户,先后使用不同代理 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池") require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池") require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理") @@ -208,7 +210,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 账户并发数为 12 - entry := svc.getOrCreateClient("", 1, 12) + entry := mustGetOrCreateClient(s.T(), svc, "", 1, 12) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") // 连接池参数应与并发数一致 @@ -228,7 +230,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() { } svc := s.newService() // 账户并发数为 0,应使用全局配置 - entry := svc.getOrCreateClient("", 1, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 1, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch") @@ -245,12 +247,12 @@ func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() { } svc := s.newService() // 创建两个客户端,设置不同的最后使用时间 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 1) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 2, 1) 
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久 atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano()) // 创建第三个客户端,触发淘汰 - _ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1) + _ = mustGetOrCreateClient(s.T(), svc, "http://proxy-c:8080", 3, 1) require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内") require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理") @@ -264,12 +266,12 @@ func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() { ClientIdleTTLSeconds: 1, // 1 秒空闲超时 } svc := s.newService() - entry1 := svc.getOrCreateClient("", 1, 1) + entry1 := mustGetOrCreateClient(s.T(), svc, "", 1, 1) // 设置为很久之前使用,但有活跃请求 atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano()) atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求 // 创建新客户端,触发淘汰检查 - _ = svc.getOrCreateClient("", 2, 1) + _, _ = svc.getOrCreateClient("", 2, 1) require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收") } @@ -279,6 +281,14 @@ func TestHTTPUpstreamSuite(t *testing.T) { suite.Run(t, new(HTTPUpstreamSuite)) } +// mustGetOrCreateClient 测试辅助函数,调用 getOrCreateClient 并断言无错误 +func mustGetOrCreateClient(t *testing.T, svc *httpUpstreamService, proxyURL string, accountID int64, concurrency int) *upstreamClientEntry { + t.Helper() + entry, err := svc.getOrCreateClient(proxyURL, accountID, concurrency) + require.NoError(t, err, "getOrCreateClient(%q, %d, %d)", proxyURL, accountID, concurrency) + return entry +} + // hasEntry 检查客户端是否存在于缓存中 // 辅助函数,用于验证淘汰逻辑 func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool { diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index a60ba294..9cf3b392 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -66,6 +66,13 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil 
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {}, }, }, + "061_add_usage_log_request_type.sql": { + fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", + acceptedDBChecksum: map[string]struct{}{ + "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {}, + "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {}, + }, + }, } // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go index 54f5b0ec..6c3ad725 100644 --- a/backend/internal/repository/migrations_runner_checksum_test.go +++ b/backend/internal/repository/migrations_runner_checksum_test.go @@ -25,6 +25,24 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { require.False(t, ok) }) + t.Run("061历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "061_add_usage_log_request_type.sql", + "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", + "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", + ) + require.True(t, ok) + }) + + t.Run("061第二个历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "061_add_usage_log_request_type.sql", + "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3", + "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", + ) + require.True(t, ok) + }) + t.Run("非白名单迁移不兼容", func(t *testing.T) { ok := isMigrationChecksumCompatible( "001_init.sql", diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index 72422d18..dd3019bb 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { requireColumn(t, tx, 
"usage_logs", "request_type", "smallint", 0, false) requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false) + // usage_billing_dedup: billing idempotency narrow table + var usageBillingDedupRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass)) + require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist") + requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false) + requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key") + requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin") + + var usageBillingDedupArchiveRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass)) + require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist") + requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false) + requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey") + // settings table should exist var settingsRegclass sql.NullString require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass)) @@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false) } +func requireIndex(t *testing.T, tx *sql.Tx, table, index string) { + t.Helper() + + var exists bool + err := tx.QueryRowContext(context.Background(), ` +SELECT EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = $1 + AND indexname = $2 +) +`, table, index).Scan(&exists) + require.NoError(t, 
err, "query pg_indexes for %s.%s", table, index) + require.True(t, exists, "expected index %s on %s", index, table) +} + func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { t.Helper() diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 3e155971..dca0b612 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -23,7 +23,10 @@ type openaiOAuthService struct { } func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) + client, err := createOpenAIReqClient(proxyURL) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) + } if redirectURI == "" { redirectURI = openai.DefaultRedirectURI @@ -74,7 +77,10 @@ func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refre } func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) + client, err := createOpenAIReqClient(proxyURL) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) + } formData := url.Values{} formData.Set("grant_type", "refresh_token") @@ -102,7 +108,7 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre return &tokenResp, nil } -func createOpenAIReqClient(proxyURL string) *req.Client { +func createOpenAIReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 120 * time.Second, diff --git a/backend/internal/repository/ops_repo.go 
b/backend/internal/repository/ops_repo.go index 989573f2..02ca1a3b 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -16,19 +16,7 @@ type opsRepository struct { db *sql.DB } -func NewOpsRepository(db *sql.DB) service.OpsRepository { - return &opsRepository{db: db} -} - -func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) { - if r == nil || r.db == nil { - return 0, fmt.Errorf("nil ops repository") - } - if input == nil { - return 0, fmt.Errorf("nil input") - } - - q := ` +const insertOpsErrorLogSQL = ` INSERT INTO ops_error_logs ( request_id, client_request_id, @@ -70,12 +58,77 @@ INSERT INTO ops_error_logs ( created_at ) VALUES ( $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 -) RETURNING id` +)` + +func NewOpsRepository(db *sql.DB) service.OpsRepository { + return &opsRepository{db: db} +} + +func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if input == nil { + return 0, fmt.Errorf("nil input") + } var id int64 err := r.db.QueryRowContext( ctx, - q, + insertOpsErrorLogSQL+" RETURNING id", + opsInsertErrorLogArgs(input)..., + ).Scan(&id) + if err != nil { + return 0, err + } + return id, nil +} + +func (r *opsRepository) BatchInsertErrorLogs(ctx context.Context, inputs []*service.OpsInsertErrorLogInput) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if len(inputs) == 0 { + return 0, nil + } + + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + + stmt, err := tx.PrepareContext(ctx, insertOpsErrorLogSQL) + if err != nil { + return 0, err + } + defer func() { + _ = stmt.Close() 
+ }() + + var inserted int64 + for _, input := range inputs { + if input == nil { + continue + } + if _, err = stmt.ExecContext(ctx, opsInsertErrorLogArgs(input)...); err != nil { + return inserted, err + } + inserted++ + } + + if err = tx.Commit(); err != nil { + return inserted, err + } + return inserted, nil +} + +func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any { + return []any{ opsNullString(input.RequestID), opsNullString(input.ClientRequestID), opsNullInt64(input.UserID), @@ -114,11 +167,7 @@ INSERT INTO ops_error_logs ( input.IsRetryable, input.RetryCount, input.CreatedAt, - ).Scan(&id) - if err != nil { - return 0, err } - return id, nil } func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) { diff --git a/backend/internal/repository/ops_repo_latency_histogram_buckets.go b/backend/internal/repository/ops_repo_latency_histogram_buckets.go index cd5bed37..e56903f1 100644 --- a/backend/internal/repository/ops_repo_latency_histogram_buckets.go +++ b/backend/internal/repository/ops_repo_latency_histogram_buckets.go @@ -35,12 +35,12 @@ func latencyHistogramRangeCaseExpr(column string) string { if b.upperMs <= 0 { continue } - _, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label)) + fmt.Fprintf(&sb, "\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label) } // Default bucket. 
last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1] - _, _ = sb.WriteString(fmt.Sprintf("\tELSE '%s'\n", last.label)) + fmt.Fprintf(&sb, "\tELSE '%s'\n", last.label) _, _ = sb.WriteString("END") return sb.String() } @@ -54,11 +54,11 @@ func latencyHistogramRangeOrderCaseExpr(column string) string { if b.upperMs <= 0 { continue } - _, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN %d\n", column, b.upperMs, order)) + fmt.Fprintf(&sb, "\tWHEN %s < %d THEN %d\n", column, b.upperMs, order) order++ } - _, _ = sb.WriteString(fmt.Sprintf("\tELSE %d\n", order)) + fmt.Fprintf(&sb, "\tELSE %d\n", order) _, _ = sb.WriteString("END") return sb.String() } diff --git a/backend/internal/repository/ops_write_pressure_integration_test.go b/backend/internal/repository/ops_write_pressure_integration_test.go new file mode 100644 index 00000000..ebb7a842 --- /dev/null +++ b/backend/internal/repository/ops_write_pressure_integration_test.go @@ -0,0 +1,79 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) { + ctx := context.Background() + _, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY") + + repo := NewOpsRepository(integrationDB).(*opsRepository) + now := time.Now().UTC() + inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{ + { + RequestID: "batch-ops-1", + ErrorPhase: "upstream", + ErrorType: "upstream_error", + Severity: "error", + StatusCode: 429, + ErrorMessage: "rate limited", + CreatedAt: now, + }, + { + RequestID: "batch-ops-2", + ErrorPhase: "internal", + ErrorType: "api_error", + Severity: "error", + StatusCode: 500, + ErrorMessage: "internal error", + CreatedAt: now.Add(time.Millisecond), + }, + }) + require.NoError(t, err) + require.EqualValues(t, 2, inserted) + + var count int + require.NoError(t, 
integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count)) + require.Equal(t, 2, count) +} + +func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) { + ctx := context.Background() + _, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY") + + accountID := int64(12345) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil)) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil)) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count)) + require.Equal(t, 1, count) + + time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil)) + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count)) + require.Equal(t, 2, count) +} + +func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) { + ctx := context.Background() + _, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY") + + accountID := int64(67890) + payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}} + payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}} + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1)) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2)) + + var count int + require.NoError(t, 
integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count)) + require.Equal(t, 2, count) +} diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go index 07d796b8..ee8e1749 100644 --- a/backend/internal/repository/pricing_service.go +++ b/backend/internal/repository/pricing_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "log/slog" "net/http" "strings" "time" @@ -16,14 +17,37 @@ type pricingRemoteClient struct { httpClient *http.Client } +// pricingRemoteClientError 代理初始化失败时的错误占位客户端 +// 所有请求直接返回初始化错误,禁止回退到直连 +type pricingRemoteClientError struct { + err error +} + +func (c *pricingRemoteClientError) FetchPricingJSON(_ context.Context, _ string) ([]byte, error) { + return nil, c.err +} + +func (c *pricingRemoteClientError) FetchHashText(_ context.Context, _ string) (string, error) { + return "", c.err +} + // NewPricingRemoteClient 创建定价数据远程客户端 // proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议 -func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient { +// 代理配置失败时行为由 allowDirectOnProxyError 控制: +// - false(默认):返回错误占位客户端,禁止回退到直连 +// - true:回退到直连(仅限管理员显式开启) +func NewPricingRemoteClient(proxyURL string, allowDirectOnProxyError bool) service.PricingRemoteClient { + // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据, + // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy client init failed, all requests will fail", "service", "pricing", "error", err) + return &pricingRemoteClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} + } sharedClient = &http.Client{Timeout: 30 * 
time.Second} } return &pricingRemoteClient{ diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go index 6ea11211..ef2f214b 100644 --- a/backend/internal/repository/pricing_service_test.go +++ b/backend/internal/repository/pricing_service_test.go @@ -19,7 +19,7 @@ type PricingServiceSuite struct { func (s *PricingServiceSuite) SetupTest() { s.ctx = context.Background() - client, ok := NewPricingRemoteClient("").(*pricingRemoteClient) + client, ok := NewPricingRemoteClient("", false).(*pricingRemoteClient) require.True(s.T(), ok, "type assertion failed") s.client = client } @@ -140,6 +140,22 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() { require.Error(s.T(), err) } +func TestNewPricingRemoteClient_InvalidProxy_NoFallback(t *testing.T) { + client := NewPricingRemoteClient("://bad", false) + _, ok := client.(*pricingRemoteClientError) + require.True(t, ok, "should return error client when proxy is invalid and fallback disabled") + + _, err := client.FetchPricingJSON(context.Background(), "http://example.com") + require.Error(t, err) + require.Contains(t, err.Error(), "proxy client init failed") +} + +func TestNewPricingRemoteClient_InvalidProxy_WithFallback(t *testing.T) { + client := NewPricingRemoteClient("://bad", true) + _, ok := client.(*pricingRemoteClient) + require.True(t, ok, "should fallback to direct client when allowed") +} + func TestPricingServiceSuite(t *testing.T) { suite.Run(t, new(PricingServiceSuite)) } diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 54de2897..b4aeab71 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -66,7 +66,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ProxyURL: proxyURL, Timeout: defaultProxyProbeTimeout, InsecureSkipVerify: s.insecureSkipVerify, - 
ProxyStrict: true, ValidateResolvedIP: s.validateResolvedIP, AllowPrivateHosts: s.allowPrivateHosts, }) diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go index af71a7ee..32501f7b 100644 --- a/backend/internal/repository/req_client_pool.go +++ b/backend/internal/repository/req_client_pool.go @@ -6,6 +6,8 @@ import ( "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/imroc/req/v3" ) @@ -33,11 +35,11 @@ var sharedReqClients sync.Map // getSharedReqClient 获取共享的 req 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 -func getSharedReqClient(opts reqClientOptions) *req.Client { +func getSharedReqClient(opts reqClientOptions) (*req.Client, error) { key := buildReqClientKey(opts) if cached, ok := sharedReqClients.Load(key); ok { if c, ok := cached.(*req.Client); ok { - return c + return c, nil } } @@ -48,15 +50,19 @@ func getSharedReqClient(opts reqClientOptions) *req.Client { if opts.Impersonate { client = client.ImpersonateChrome() } - if strings.TrimSpace(opts.ProxyURL) != "" { - client.SetProxyURL(strings.TrimSpace(opts.ProxyURL)) + trimmed, _, err := proxyurl.Parse(opts.ProxyURL) + if err != nil { + return nil, err + } + if trimmed != "" { + client.SetProxyURL(trimmed) } actual, _ := sharedReqClients.LoadOrStore(key, client) if c, ok := actual.(*req.Client); ok { - return c + return c, nil } - return client + return client, nil } func buildReqClientKey(opts reqClientOptions) string { @@ -67,3 +73,14 @@ func buildReqClientKey(opts reqClientOptions) string { opts.ForceHTTP2, ) } + +// CreatePrivacyReqClient creates an HTTP client for OpenAI privacy settings API +// This is exported for use by OpenAIPrivacyService +// Uses Chrome TLS fingerprint impersonation to bypass Cloudflare checks +func CreatePrivacyReqClient(proxyURL string) (*req.Client, error) { + return getSharedReqClient(reqClientOptions{ + ProxyURL: proxyURL, + Timeout: 30 * time.Second, + Impersonate: true, // Enable Chrome TLS 
fingerprint impersonation + }) +} diff --git a/backend/internal/repository/req_client_pool_test.go b/backend/internal/repository/req_client_pool_test.go index 904ed4d6..9067d012 100644 --- a/backend/internal/repository/req_client_pool_test.go +++ b/backend/internal/repository/req_client_pool_test.go @@ -26,11 +26,13 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: time.Second, } - clientDefault := getSharedReqClient(base) + clientDefault, err := getSharedReqClient(base) + require.NoError(t, err) force := base force.ForceHTTP2 = true - clientForce := getSharedReqClient(force) + clientForce, err := getSharedReqClient(force) + require.NoError(t, err) require.NotSame(t, clientDefault, clientForce) require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force)) @@ -42,8 +44,10 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: 2 * time.Second, } - first := getSharedReqClient(opts) - second := getSharedReqClient(opts) + first, err := getSharedReqClient(opts) + require.NoError(t, err) + second, err := getSharedReqClient(opts) + require.NoError(t, err) require.Same(t, first, second) } @@ -56,7 +60,8 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) { key := buildReqClientKey(opts) sharedReqClients.Store(key, "invalid") - client := getSharedReqClient(opts) + client, err := getSharedReqClient(opts) + require.NoError(t, err) require.NotNil(t, client) loaded, ok := sharedReqClients.Load(key) @@ -71,20 +76,45 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) { Timeout: 4 * time.Second, Impersonate: true, } - client := getSharedReqClient(opts) + client, err := getSharedReqClient(opts) + require.NoError(t, err) require.NotNil(t, client) require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts)) } +func TestGetSharedReqClient_InvalidProxyURL(t *testing.T) { + sharedReqClients = sync.Map{} + 
opts := reqClientOptions{ + ProxyURL: "://missing-scheme", + Timeout: time.Second, + } + _, err := getSharedReqClient(opts) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid proxy URL") +} + +func TestGetSharedReqClient_ProxyURLMissingHost(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: "http://", + Timeout: time.Second, + } + _, err := getSharedReqClient(opts) + require.Error(t, err) + require.Contains(t, err.Error(), "proxy URL missing host") +} + func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) { sharedReqClients = sync.Map{} - client := createOpenAIReqClient("http://proxy.local:8080") + client, err := createOpenAIReqClient("http://proxy.local:8080") + require.NoError(t, err) require.Equal(t, 120*time.Second, client.GetClient().Timeout) } func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) { sharedReqClients = sync.Map{} - client := createGeminiReqClient("http://proxy.local:8080") + client, err := createGeminiReqClient("http://proxy.local:8080") + require.NoError(t, err) require.Equal(t, "", forceHTTPVersion(t, client)) } diff --git a/backend/internal/repository/scheduled_test_repo.go b/backend/internal/repository/scheduled_test_repo.go new file mode 100644 index 00000000..c03d1df9 --- /dev/null +++ b/backend/internal/repository/scheduled_test_repo.go @@ -0,0 +1,183 @@ +package repository + +import ( + "context" + "database/sql" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// --- Plan Repository --- + +type scheduledTestPlanRepository struct { + db *sql.DB +} + +func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanRepository { + return &scheduledTestPlanRepository{db: db} +} + +func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) { + row := r.db.QueryRowContext(ctx, ` + INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, 
max_results, auto_recover, next_run_at, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW()) + RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at + `, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt) + return scanPlan(row) +} + +func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at + FROM scheduled_test_plans WHERE id = $1 + `, id) + return scanPlan(row) +} + +func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at + FROM scheduled_test_plans WHERE account_id = $1 + ORDER BY created_at DESC + `, accountID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanPlans(rows) +} + +func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at + FROM scheduled_test_plans + WHERE enabled = true AND next_run_at <= $1 + ORDER BY next_run_at ASC + `, now) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanPlans(rows) +} + +func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) { + row := r.db.QueryRowContext(ctx, ` + UPDATE 
scheduled_test_plans + SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, auto_recover = $6, next_run_at = $7, updated_at = NOW() + WHERE id = $1 + RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at + `, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt) + return scanPlan(row) +} + +func (r *scheduledTestPlanRepository) Delete(ctx context.Context, id int64) error { + _, err := r.db.ExecContext(ctx, `DELETE FROM scheduled_test_plans WHERE id = $1`, id) + return err +} + +func (r *scheduledTestPlanRepository) UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error { + _, err := r.db.ExecContext(ctx, ` + UPDATE scheduled_test_plans SET last_run_at = $2, next_run_at = $3, updated_at = NOW() WHERE id = $1 + `, id, lastRunAt, nextRunAt) + return err +} + +// --- Result Repository --- + +type scheduledTestResultRepository struct { + db *sql.DB +} + +func NewScheduledTestResultRepository(db *sql.DB) service.ScheduledTestResultRepository { + return &scheduledTestResultRepository{db: db} +} + +func (r *scheduledTestResultRepository) Create(ctx context.Context, result *service.ScheduledTestResult) (*service.ScheduledTestResult, error) { + row := r.db.QueryRowContext(ctx, ` + INSERT INTO scheduled_test_results (plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) + RETURNING id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at + `, result.PlanID, result.Status, result.ResponseText, result.ErrorMessage, result.LatencyMs, result.StartedAt, result.FinishedAt) + + out := &service.ScheduledTestResult{} + if err := row.Scan( + &out.ID, &out.PlanID, &out.Status, &out.ResponseText, &out.ErrorMessage, + &out.LatencyMs, &out.StartedAt, &out.FinishedAt, 
&out.CreatedAt, + ); err != nil { + return nil, err + } + return out, nil +} + +func (r *scheduledTestResultRepository) ListByPlanID(ctx context.Context, planID int64, limit int) ([]*service.ScheduledTestResult, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at + FROM scheduled_test_results + WHERE plan_id = $1 + ORDER BY created_at DESC + LIMIT $2 + `, planID, limit) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var results []*service.ScheduledTestResult + for rows.Next() { + r := &service.ScheduledTestResult{} + if err := rows.Scan( + &r.ID, &r.PlanID, &r.Status, &r.ResponseText, &r.ErrorMessage, + &r.LatencyMs, &r.StartedAt, &r.FinishedAt, &r.CreatedAt, + ); err != nil { + return nil, err + } + results = append(results, r) + } + return results, rows.Err() +} + +func (r *scheduledTestResultRepository) PruneOldResults(ctx context.Context, planID int64, keepCount int) error { + _, err := r.db.ExecContext(ctx, ` + DELETE FROM scheduled_test_results + WHERE id IN ( + SELECT id FROM ( + SELECT id, ROW_NUMBER() OVER (PARTITION BY plan_id ORDER BY created_at DESC) AS rn + FROM scheduled_test_results + WHERE plan_id = $1 + ) ranked + WHERE rn > $2 + ) + `, planID, keepCount) + return err +} + +// --- scan helpers --- + +type scannable interface { + Scan(dest ...any) error +} + +func scanPlan(row scannable) (*service.ScheduledTestPlan, error) { + p := &service.ScheduledTestPlan{} + if err := row.Scan( + &p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults, &p.AutoRecover, + &p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt, + ); err != nil { + return nil, err + } + return p, nil +} + +func scanPlans(rows *sql.Rows) ([]*service.ScheduledTestPlan, error) { + var plans []*service.ScheduledTestPlan + for rows.Next() { + p, err := scanPlan(rows) + if err != nil { + return nil, err + } + plans = 
append(plans, p) + } + return plans, rows.Err() +} diff --git a/backend/internal/repository/scheduler_outbox_repo.go b/backend/internal/repository/scheduler_outbox_repo.go index d7bc97da..4b9a9f58 100644 --- a/backend/internal/repository/scheduler_outbox_repo.go +++ b/backend/internal/repository/scheduler_outbox_repo.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "time" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -12,6 +13,8 @@ type schedulerOutboxRepository struct { db *sql.DB } +const schedulerOutboxDedupWindow = time.Second + func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository { return &schedulerOutboxRepository{db: db} } @@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str } payloadArg = encoded } - _, err := exec.ExecContext(ctx, ` + query := ` INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload) VALUES ($1, $2, $3, $4) - `, eventType, accountID, groupID, payloadArg) + ` + args := []any{eventType, accountID, groupID, payloadArg} + if schedulerOutboxEventSupportsDedup(eventType) { + query = ` + INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload) + SELECT $1, $2, $3, $4 + WHERE NOT EXISTS ( + SELECT 1 + FROM scheduler_outbox + WHERE event_type = $1 + AND account_id IS NOT DISTINCT FROM $2 + AND group_id IS NOT DISTINCT FROM $3 + AND created_at >= NOW() - make_interval(secs => $5) + ) + ` + args = append(args, schedulerOutboxDedupWindow.Seconds()) + } + _, err := exec.ExecContext(ctx, query, args...) 
return err } + +func schedulerOutboxEventSupportsDedup(eventType string) bool { + switch eventType { + case service.SchedulerOutboxEventAccountChanged, + service.SchedulerOutboxEventGroupChanged, + service.SchedulerOutboxEventFullRebuild: + return true + default: + return false + } +} diff --git a/backend/internal/repository/setting_repo_integration_test.go b/backend/internal/repository/setting_repo_integration_test.go index 147313d6..f37b2de1 100644 --- a/backend/internal/repository/setting_repo_integration_test.go +++ b/backend/internal/repository/setting_repo_integration_test.go @@ -122,7 +122,7 @@ func (s *SettingRepoSuite) TestSet_EmptyValue() { func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() { // 模拟保存站点设置,部分字段有值,部分字段为空 settings := map[string]string{ - "site_name": "AICodex2API", + "site_name": "Sub2api", "site_subtitle": "Subscription to API", "site_logo": "", // 用户未上传Logo "api_base_url": "", // 用户未设置API地址 @@ -136,7 +136,7 @@ func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() { result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"}) s.Require().NoError(err, "GetMultiple after SetMultiple with empty values") - s.Require().Equal("AICodex2API", result["site_name"]) + s.Require().Equal("Sub2api", result["site_name"]) s.Require().Equal("Subscription to API", result["site_subtitle"]) s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved") s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved") diff --git a/backend/internal/repository/simple_mode_admin_concurrency.go b/backend/internal/repository/simple_mode_admin_concurrency.go new file mode 100644 index 00000000..4d1db150 --- /dev/null +++ b/backend/internal/repository/simple_mode_admin_concurrency.go @@ -0,0 +1,55 @@ +package repository + +import ( + "context" + "fmt" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + 
"github.com/Wei-Shaw/sub2api/ent/setting" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +const ( + simpleModeAdminConcurrencyUpgradeKey = "simple_mode_admin_concurrency_upgraded_30" + simpleModeLegacyAdminConcurrency = 5 + simpleModeTargetAdminConcurrency = 30 +) + +func ensureSimpleModeAdminConcurrency(ctx context.Context, client *dbent.Client) error { + if client == nil { + return fmt.Errorf("nil ent client") + } + + upgraded, err := client.Setting.Query().Where(setting.KeyEQ(simpleModeAdminConcurrencyUpgradeKey)).Exist(ctx) + if err != nil { + return fmt.Errorf("check admin concurrency upgrade marker: %w", err) + } + if upgraded { + return nil + } + + if _, err := client.User.Update(). + Where( + dbuser.RoleEQ(service.RoleAdmin), + dbuser.ConcurrencyEQ(simpleModeLegacyAdminConcurrency), + ). + SetConcurrency(simpleModeTargetAdminConcurrency). + Save(ctx); err != nil { + return fmt.Errorf("upgrade simple mode admin concurrency: %w", err) + } + + now := time.Now() + if err := client.Setting.Create(). + SetKey(simpleModeAdminConcurrencyUpgradeKey). + SetValue(now.Format(time.RFC3339)). + SetUpdatedAt(now). + OnConflictColumns(setting.FieldKey). + UpdateNewValues(). 
+ Exec(ctx); err != nil { + return fmt.Errorf("persist admin concurrency upgrade marker: %w", err) + } + + return nil +} diff --git a/backend/internal/repository/soft_delete_ent_integration_test.go b/backend/internal/repository/soft_delete_ent_integration_test.go index ef63fbee..8c2b23da 100644 --- a/backend/internal/repository/soft_delete_ent_integration_test.go +++ b/backend/internal/repository/soft_delete_ent_integration_test.go @@ -41,7 +41,7 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) { u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com") - repo := NewAPIKeyRepository(client) + repo := NewAPIKeyRepository(client, integrationDB) key := &service.APIKey{ UserID: u.ID, Key: uniqueSoftDeleteValue(t, "sk-soft-delete"), @@ -73,7 +73,7 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) { u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com") - repo := NewAPIKeyRepository(client) + repo := NewAPIKeyRepository(client, integrationDB) key := &service.APIKey{ UserID: u.ID, Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"), @@ -93,7 +93,7 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) { u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com") - repo := NewAPIKeyRepository(client) + repo := NewAPIKeyRepository(client, integrationDB) key := &service.APIKey{ UserID: u.ID, Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"), diff --git a/backend/internal/repository/usage_billing_repo.go b/backend/internal/repository/usage_billing_repo.go new file mode 100644 index 00000000..b4c76da5 --- /dev/null +++ b/backend/internal/repository/usage_billing_repo.go @@ -0,0 +1,308 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type 
usageBillingRepository struct { + db *sql.DB +} + +func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository { + return &usageBillingRepository{db: sqlDB} +} + +func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) { + if cmd == nil { + return &service.UsageBillingApplyResult{}, nil + } + if r == nil || r.db == nil { + return nil, errors.New("usage billing repository db is nil") + } + + cmd.Normalize() + if cmd.RequestID == "" { + return nil, service.ErrUsageBillingRequestIDRequired + } + + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + if tx != nil { + _ = tx.Rollback() + } + }() + + applied, err := r.claimUsageBillingKey(ctx, tx, cmd) + if err != nil { + return nil, err + } + if !applied { + return &service.UsageBillingApplyResult{Applied: false}, nil + } + + result := &service.UsageBillingApplyResult{Applied: true} + if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + tx = nil + return result, nil +} + +func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) { + var id int64 + err := tx.QueryRowContext(ctx, ` + INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint) + VALUES ($1, $2, $3) + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING id + `, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + var existingFingerprint string + if err := tx.QueryRowContext(ctx, ` + SELECT request_fingerprint + FROM usage_billing_dedup + WHERE request_id = $1 AND api_key_id = $2 + `, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil { + return false, err + } + if strings.TrimSpace(existingFingerprint) != 
strings.TrimSpace(cmd.RequestFingerprint) { + return false, service.ErrUsageBillingRequestConflict + } + return false, nil + } + if err != nil { + return false, err + } + var archivedFingerprint string + err = tx.QueryRowContext(ctx, ` + SELECT request_fingerprint + FROM usage_billing_dedup_archive + WHERE request_id = $1 AND api_key_id = $2 + `, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint) + if err == nil { + if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) { + return false, service.ErrUsageBillingRequestConflict + } + return false, nil + } + if !errors.Is(err, sql.ErrNoRows) { + return false, err + } + return true, nil +} + +func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error { + if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil { + if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil { + return err + } + } + + if cmd.BalanceCost > 0 { + if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil { + return err + } + } + + if cmd.APIKeyQuotaCost > 0 { + exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost) + if err != nil { + return err + } + result.APIKeyQuotaExhausted = exhausted + } + + if cmd.APIKeyRateLimitCost > 0 { + if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil { + return err + } + } + + if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) { + if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil { + return err + } + } + + return nil +} + +func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) 
error { + const updateSQL = ` + UPDATE user_subscriptions us + SET + daily_usage_usd = us.daily_usage_usd + $1, + weekly_usage_usd = us.weekly_usage_usd + $1, + monthly_usage_usd = us.monthly_usage_usd + $1, + updated_at = NOW() + FROM groups g + WHERE us.id = $2 + AND us.deleted_at IS NULL + AND us.group_id = g.id + AND g.deleted_at IS NULL + ` + res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected > 0 { + return nil + } + return service.ErrSubscriptionNotFound +} + +func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error { + res, err := tx.ExecContext(ctx, ` + UPDATE users + SET balance = balance - $1, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + `, amount, userID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected > 0 { + return nil + } + return service.ErrUserNotFound +} + +func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) { + var exhausted bool + err := tx.QueryRowContext(ctx, ` + UPDATE api_keys + SET quota_used = quota_used + $1, + status = CASE + WHEN quota > 0 + AND status = $3 + AND quota_used < quota + AND quota_used + $1 >= quota + THEN $4 + ELSE status + END, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota + `, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted) + if errors.Is(err, sql.ErrNoRows) { + return false, service.ErrAPIKeyNotFound + } + if err != nil { + return false, err + } + return exhausted, nil +} + +func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error { + res, err := tx.ExecContext(ctx, ` + UPDATE api_keys SET + usage_5h 
= CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END, + usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END, + usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END, + window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END, + window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END, + window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + `, cost, apiKeyID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrAPIKeyNotFound + } + return nil +} + +func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error { + rows, err := tx.QueryContext(ctx, + `UPDATE accounts SET extra = ( + COALESCE(extra, '{}'::jsonb) + || jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1) + || CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 'quota_daily_used', + CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + THEN $1 + ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END, + 'quota_daily_start', + CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END + ) + ELSE '{}'::jsonb END + || 
CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 'quota_weekly_used', + CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + THEN $1 + ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END, + 'quota_weekly_start', + CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END + ) + ELSE '{}'::jsonb END + ), updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + RETURNING + COALESCE((extra->>'quota_used')::numeric, 0), + COALESCE((extra->>'quota_limit')::numeric, 0)`, + amount, accountID) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + var newUsed, limit float64 + if rows.Next() { + if err := rows.Scan(&newUsed, &limit); err != nil { + return err + } + } else { + if err := rows.Err(); err != nil { + return err + } + return service.ErrAccountNotFound + } + if err := rows.Err(); err != nil { + return err + } + if limit > 0 && newUsed >= limit && (newUsed-amount) < limit { + if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil { + logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err) + return err + } + } + return nil +} diff --git a/backend/internal/repository/usage_billing_repo_integration_test.go b/backend/internal/repository/usage_billing_repo_integration_test.go new file mode 100644 index 00000000..eda34cc9 --- /dev/null +++ b/backend/internal/repository/usage_billing_repo_integration_test.go @@ -0,0 +1,279 @@ +//go:build integration + +package repository + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + 
"github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Balance: 100, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-" + uuid.NewString(), + Name: "billing", + Quota: 1, + }) + account := mustCreateAccount(t, client, &service.Account{ + Name: "usage-billing-account-" + uuid.NewString(), + Type: service.AccountTypeAPIKey, + }) + + requestID := uuid.NewString() + cmd := &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + AccountID: account.ID, + AccountType: service.AccountTypeAPIKey, + BalanceCost: 1.25, + APIKeyQuotaCost: 1.25, + APIKeyRateLimitCost: 1.25, + } + + result1, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.NotNil(t, result1) + require.True(t, result1.Applied) + require.True(t, result1.APIKeyQuotaExhausted) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.NotNil(t, result2) + require.False(t, result2.Applied) + + var balance float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance)) + require.InDelta(t, 98.75, balance, 0.000001) + + var quotaUsed float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan(&quotaUsed)) + require.InDelta(t, 1.25, quotaUsed, 0.000001) + + var usage5h float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h)) + require.InDelta(t, 1.25, usage5h, 0.000001) + + var status string + require.NoError(t, 
integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status)) + require.Equal(t, service.StatusAPIKeyQuotaExhausted, status) + + var dedupCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount)) + require.Equal(t, 1, dedupCount) +} + +func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + }) + group := mustCreateGroup(t, client, &service.Group{ + Name: "usage-billing-group-" + uuid.NewString(), + Platform: service.PlatformAnthropic, + SubscriptionType: service.SubscriptionTypeSubscription, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + GroupID: &group.ID, + Key: "sk-usage-billing-sub-" + uuid.NewString(), + Name: "billing-sub", + }) + subscription := mustCreateSubscription(t, client, &service.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + }) + + requestID := uuid.NewString() + cmd := &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + AccountID: 0, + SubscriptionID: &subscription.ID, + SubscriptionCost: 2.5, + } + + result1, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.True(t, result1.Applied) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.False(t, result2.Applied) + + var dailyUsage float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage)) + require.InDelta(t, 2.5, dailyUsage, 0.000001) +} + +func 
TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Balance: 100, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-conflict-" + uuid.NewString(), + Name: "billing-conflict", + }) + + requestID := uuid.NewString() + _, err := repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + BalanceCost: 1.25, + }) + require.NoError(t, err) + + _, err = repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + BalanceCost: 2.50, + }) + require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict) +} + +func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-account-" + uuid.NewString(), + Name: "billing-account", + }) + account := mustCreateAccount(t, client, &service.Account{ + Name: "usage-billing-account-quota-" + uuid.NewString(), + Type: service.AccountTypeAPIKey, + Extra: map[string]any{ + "quota_limit": 100.0, + }, + }) + + _, err := repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: uuid.NewString(), + APIKeyID: apiKey.ID, + UserID: user.ID, + AccountID: account.ID, + AccountType: service.AccountTypeAPIKey, + AccountQuotaCost: 3.5, + }) + require.NoError(t, err) + + var quotaUsed 
float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan(&quotaUsed)) + require.InDelta(t, 3.5, quotaUsed, 0.000001) +} + +func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) { + ctx := context.Background() + repo := newDashboardAggregationRepositoryWithSQL(integrationDB) + + oldRequestID := "dedup-old-" + uuid.NewString() + newRequestID := "dedup-new-" + uuid.NewString() + oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400) + newCreatedAt := time.Now().UTC().Add(-time.Hour) + + _, err := integrationDB.ExecContext(ctx, ` + INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at) + VALUES ($1, 1, $2, $3), ($4, 1, $5, $6) + `, + oldRequestID, strings.Repeat("a", 64), oldCreatedAt, + newRequestID, strings.Repeat("b", 64), newCreatedAt, + ) + require.NoError(t, err) + + require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365))) + + var oldCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount)) + require.Equal(t, 0, oldCount) + + var newCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount)) + require.Equal(t, 1, newCount) + + var archivedCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount)) + require.Equal(t, 1, archivedCount) +} + +func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB) + + user := 
mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Balance: 100, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-archive-" + uuid.NewString(), + Name: "billing-archive", + }) + + requestID := uuid.NewString() + cmd := &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + BalanceCost: 1.25, + } + + result1, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.True(t, result1.Applied) + + _, err = integrationDB.ExecContext(ctx, ` + UPDATE usage_billing_dedup + SET created_at = $1 + WHERE request_id = $2 AND api_key_id = $3 + `, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID) + require.NoError(t, err) + require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365))) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.False(t, result2.Applied) + + var balance float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance)) + require.InDelta(t, 98.75, balance, 0.000001) +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index d30cc7dd..dc70812d 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3,10 +3,14 @@ package repository import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "os" + "strconv" "strings" + "sync" + "sync/atomic" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -15,14 +19,57 @@ import ( dbgroup "github.com/Wei-Shaw/sub2api/ent/group" dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" 
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" + gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" + +var usageLogInsertArgTypes = [...]string{ + "bigint", + "bigint", + "bigint", + "text", + "text", + "bigint", + "bigint", + "integer", + "integer", + "integer", + "integer", + "integer", + "integer", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "smallint", + "smallint", + "boolean", + "boolean", + "integer", + "integer", + "text", + "text", + "integer", + "text", + "text", + "text", + "text", + "text", + "text", + "boolean", + "timestamptz", +} // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var 
dateFormatWhitelist = map[string]string{ @@ -43,15 +90,89 @@ func safeDateFormat(granularity string) string { type usageLogRepository struct { client *dbent.Client sql sqlExecutor + db *sql.DB + + createBatchOnce sync.Once + createBatchCh chan usageLogCreateRequest + bestEffortBatchOnce sync.Once + bestEffortBatchCh chan usageLogBestEffortRequest + bestEffortRecent *gocache.Cache } +const ( + usageLogCreateBatchMaxSize = 64 + usageLogCreateBatchWindow = 3 * time.Millisecond + usageLogCreateBatchQueueCap = 4096 + usageLogCreateCancelWait = 2 * time.Second + + usageLogBestEffortBatchMaxSize = 256 + usageLogBestEffortBatchWindow = 20 * time.Millisecond + usageLogBestEffortBatchQueueCap = 32768 + usageLogBestEffortRecentTTL = 30 * time.Second +) + +type usageLogCreateRequest struct { + log *service.UsageLog + prepared usageLogInsertPrepared + shared *usageLogCreateShared + resultCh chan usageLogCreateResult +} + +type usageLogCreateResult struct { + inserted bool + err error +} + +type usageLogBestEffortRequest struct { + prepared usageLogInsertPrepared + apiKeyID int64 + resultCh chan error +} + +type usageLogInsertPrepared struct { + createdAt time.Time + requestID string + rateMultiplier float64 + requestType int16 + args []any +} + +type usageLogBatchState struct { + ID int64 + CreatedAt time.Time +} + +type usageLogBatchRow struct { + RequestID string `json:"request_id"` + APIKeyID int64 `json:"api_key_id"` + ID int64 `json:"id"` + CreatedAt time.Time `json:"created_at"` + Inserted bool `json:"inserted"` +} + +type usageLogCreateShared struct { + state atomic.Int32 +} + +const ( + usageLogCreateStateQueued int32 = iota + usageLogCreateStateProcessing + usageLogCreateStateCompleted + usageLogCreateStateCanceled +) + func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository { return newUsageLogRepositoryWithSQL(client, sqlDB) } func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository { // 使用 
scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。 - return &usageLogRepository{client: client, sql: sqlq} + repo := &usageLogRepository{client: client, sql: sqlq} + if db, ok := sqlq.(*sql.DB); ok { + repo.db = db + } + repo.bestEffortRecent = gocache.New(usageLogBestEffortRecentTTL, time.Minute) + return repo } // getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤) @@ -82,24 +203,72 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) return false, nil } - // 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。 - // 无事务时回退到默认的 *sql.DB 执行器。 - sqlq := r.sql if tx := dbent.TxFromContext(ctx); tx != nil { - sqlq = tx.Client() + return r.createSingle(ctx, tx.Client(), log) } - - createdAt := log.CreatedAt - if createdAt.IsZero() { - createdAt = time.Now() - } - requestID := strings.TrimSpace(log.RequestID) + if requestID == "" { + return r.createSingle(ctx, r.sql, log) + } log.RequestID = requestID + return r.createBatched(ctx, log) +} - rateMultiplier := log.RateMultiplier - log.SyncRequestTypeAndLegacyFields() - requestType := int16(log.RequestType) +func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error { + if log == nil { + return nil + } + + if tx := dbent.TxFromContext(ctx); tx != nil { + _, err := r.createSingle(ctx, tx.Client(), log) + return err + } + if r.db == nil { + _, err := r.createSingle(ctx, r.sql, log) + return err + } + + r.ensureBestEffortBatcher() + if r.bestEffortBatchCh == nil { + _, err := r.createSingle(ctx, r.sql, log) + return err + } + + req := usageLogBestEffortRequest{ + prepared: prepareUsageLogInsert(log), + apiKeyID: log.APIKeyID, + resultCh: make(chan error, 1), + } + if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok { + if _, exists := r.bestEffortRecent.Get(key); exists { + return nil + } + } + + select { + case r.bestEffortBatchCh <- req: + case <-ctx.Done(): + return service.MarkUsageLogCreateDropped(ctx.Err()) + default: + 
return service.MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")) + } + + select { + case err := <-req.resultCh: + return err + case <-ctx.Done(): + return service.MarkUsageLogCreateDropped(ctx.Err()) + } +} + +func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) { + prepared := prepareUsageLogInsert(log) + if sqlq == nil { + sqlq = r.sql + } + if ctx != nil && ctx.Err() != nil { + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + } query := ` INSERT INTO usage_logs ( @@ -135,7 +304,10 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) image_count, image_size, media_type, + service_tier, reasoning_effort, + inbound_endpoint, + upstream_endpoint, cache_ttl_overridden, created_at ) VALUES ( @@ -144,12 +316,799 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35 + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at ` + if err := scanSingleRow(ctx, sqlq, query, prepared.args, &log.ID, &log.CreatedAt); err != nil { + if errors.Is(err, sql.ErrNoRows) && prepared.requestID != "" { + selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" + if err := scanSingleRow(ctx, sqlq, selectQuery, []any{prepared.requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil { + return false, err + } + log.RateMultiplier = prepared.rateMultiplier + return false, nil + } else { + return false, err + } + } + log.RateMultiplier = prepared.rateMultiplier + return true, nil +} + +func (r *usageLogRepository) createBatched(ctx context.Context, log *service.UsageLog) (bool, error) { + if r.db == nil { + return r.createSingle(ctx, 
r.sql, log) + } + r.ensureCreateBatcher() + if r.createBatchCh == nil { + return r.createSingle(ctx, r.sql, log) + } + + req := usageLogCreateRequest{ + log: log, + prepared: prepareUsageLogInsert(log), + shared: &usageLogCreateShared{}, + resultCh: make(chan usageLogCreateResult, 1), + } + + select { + case r.createBatchCh <- req: + case <-ctx.Done(): + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + default: + return false, service.MarkUsageLogCreateNotPersisted(errors.New("usage log create batch queue full")) + } + + select { + case res := <-req.resultCh: + return res.inserted, res.err + case <-ctx.Done(): + if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateCanceled) { + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + } + timer := time.NewTimer(usageLogCreateCancelWait) + defer timer.Stop() + select { + case res := <-req.resultCh: + return res.inserted, res.err + case <-timer.C: + return false, ctx.Err() + } + } +} + +func (r *usageLogRepository) ensureCreateBatcher() { + if r == nil || r.db == nil || r.createBatchCh != nil { + return + } + r.createBatchOnce.Do(func() { + r.createBatchCh = make(chan usageLogCreateRequest, usageLogCreateBatchQueueCap) + go r.runCreateBatcher(r.db) + }) +} + +func (r *usageLogRepository) ensureBestEffortBatcher() { + if r == nil || r.db == nil || r.bestEffortBatchCh != nil { + return + } + r.bestEffortBatchOnce.Do(func() { + r.bestEffortBatchCh = make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap) + go r.runBestEffortBatcher(r.db) + }) +} + +func (r *usageLogRepository) runCreateBatcher(db *sql.DB) { + for { + first, ok := <-r.createBatchCh + if !ok { + return + } + + batch := make([]usageLogCreateRequest, 0, usageLogCreateBatchMaxSize) + batch = append(batch, first) + + timer := time.NewTimer(usageLogCreateBatchWindow) + batchLoop: + for len(batch) < usageLogCreateBatchMaxSize { + select { + case req, ok := <-r.createBatchCh: 
+ if !ok { + break batchLoop + } + batch = append(batch, req) + case <-timer.C: + break batchLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + + r.flushCreateBatch(db, batch) + } +} + +func (r *usageLogRepository) runBestEffortBatcher(db *sql.DB) { + for { + first, ok := <-r.bestEffortBatchCh + if !ok { + return + } + + batch := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize) + batch = append(batch, first) + + timer := time.NewTimer(usageLogBestEffortBatchWindow) + bestEffortLoop: + for len(batch) < usageLogBestEffortBatchMaxSize { + select { + case req, ok := <-r.bestEffortBatchCh: + if !ok { + break bestEffortLoop + } + batch = append(batch, req) + case <-timer.C: + break bestEffortLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + + r.flushBestEffortBatch(db, batch) + } +} + +func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) { + if len(batch) == 0 { + return + } + + uniqueOrder := make([]string, 0, len(batch)) + preparedByKey := make(map[string]usageLogInsertPrepared, len(batch)) + requestsByKey := make(map[string][]usageLogCreateRequest, len(batch)) + fallback := make([]usageLogCreateRequest, 0) + + for _, req := range batch { + if req.log == nil { + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + continue + } + if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) { + if req.shared.state.Load() == usageLogCreateStateCanceled { + completeUsageLogCreateRequest(req, usageLogCreateResult{ + inserted: false, + err: service.MarkUsageLogCreateNotPersisted(context.Canceled), + }) + continue + } + } + prepared := req.prepared + if prepared.requestID == "" { + fallback = append(fallback, req) + continue + } + key := usageLogBatchKey(prepared.requestID, req.log.APIKeyID) + if _, exists := requestsByKey[key]; !exists { + uniqueOrder = 
append(uniqueOrder, key) + preparedByKey[key] = prepared + } + requestsByKey[key] = append(requestsByKey[key], req) + } + + if len(uniqueOrder) > 0 { + insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey) + if err != nil { + if safeFallback { + for _, key := range uniqueOrder { + fallback = append(fallback, requestsByKey[key]...) + } + } else { + for _, key := range uniqueOrder { + reqs := requestsByKey[key] + state, hasState := stateMap[key] + inserted := insertedMap[key] + for idx, req := range reqs { + req.log.RateMultiplier = preparedByKey[key].rateMultiplier + if hasState { + req.log.ID = state.ID + req.log.CreatedAt = state.CreatedAt + } + switch { + case inserted && idx == 0: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil}) + case inserted: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + case hasState: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + case idx == 0: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err}) + default: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + } + } + } + } + } else { + for _, key := range uniqueOrder { + reqs := requestsByKey[key] + state, ok := stateMap[key] + if !ok { + for _, req := range reqs { + completeUsageLogCreateRequest(req, usageLogCreateResult{ + inserted: false, + err: fmt.Errorf("usage log batch state missing for key=%s", key), + }) + } + continue + } + for idx, req := range reqs { + req.log.ID = state.ID + req.log.CreatedAt = state.CreatedAt + req.log.RateMultiplier = preparedByKey[key].rateMultiplier + completeUsageLogCreateRequest(req, usageLogCreateResult{ + inserted: idx == 0 && insertedMap[key], + err: nil, + }) + } + } + } + } + + if len(fallback) == 0 { + return + } + + for _, req := range fallback { + fallbackCtx, cancel := context.WithTimeout(context.Background(), 
10*time.Second) + inserted, err := r.createSingle(fallbackCtx, db, req.log) + cancel() + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err}) + } +} + +func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) { + if len(batch) == 0 { + return + } + + type bestEffortGroup struct { + prepared usageLogInsertPrepared + apiKeyID int64 + key string + reqs []usageLogBestEffortRequest + } + + groupsByKey := make(map[string]*bestEffortGroup, len(batch)) + groupOrder := make([]*bestEffortGroup, 0, len(batch)) + preparedList := make([]usageLogInsertPrepared, 0, len(batch)) + + for idx, req := range batch { + prepared := req.prepared + key := fmt.Sprintf("__best_effort_%d", idx) + if prepared.requestID != "" { + key = usageLogBatchKey(prepared.requestID, req.apiKeyID) + } + group, exists := groupsByKey[key] + if !exists { + group = &bestEffortGroup{ + prepared: prepared, + apiKeyID: req.apiKeyID, + key: key, + } + groupsByKey[key] = group + groupOrder = append(groupOrder, group) + preparedList = append(preparedList, prepared) + } + group.reqs = append(group.reqs, req) + } + + if len(preparedList) == 0 { + for _, req := range batch { + sendUsageLogBestEffortResult(req.resultCh, nil) + } + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + query, args := buildUsageLogBestEffortInsertQuery(preparedList) + if _, err := db.ExecContext(ctx, query, args...); err != nil { + logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err) + for _, group := range groupOrder { + singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared) + if singleErr != nil { + logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr) + } else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { + r.bestEffortRecent.SetDefault(group.key, struct{}{}) + } + for _, req := range 
group.reqs { + sendUsageLogBestEffortResult(req.resultCh, singleErr) + } + } + return + } + for _, group := range groupOrder { + if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { + r.bestEffortRecent.SetDefault(group.key, struct{}{}) + } + for _, req := range group.reqs { + sendUsageLogBestEffortResult(req.resultCh, nil) + } + } +} + +func sendUsageLogBestEffortResult(ch chan error, err error) { + if ch == nil { + return + } + select { + case ch <- err: + default: + } +} + +func completeUsageLogCreateRequest(req usageLogCreateRequest, res usageLogCreateResult) { + if req.shared != nil { + req.shared.state.Store(usageLogCreateStateCompleted) + } + sendUsageLogCreateResult(req.resultCh, res) +} + +func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) { + if len(keys) == 0 { + return map[string]bool{}, map[string]usageLogBatchState{}, false, nil + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey) + var payload []byte + if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil { + return nil, nil, true, err + } + var rows []usageLogBatchRow + if err := json.Unmarshal(payload, &rows); err != nil { + return nil, nil, false, err + } + insertedMap := make(map[string]bool, len(keys)) + stateMap := make(map[string]usageLogBatchState, len(keys)) + for _, row := range rows { + key := usageLogBatchKey(row.RequestID, row.APIKeyID) + insertedMap[key] = row.Inserted + stateMap[key] = usageLogBatchState{ + ID: row.ID, + CreatedAt: row.CreatedAt, + } + } + if len(stateMap) != len(keys) { + return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys)) + } + return insertedMap, stateMap, false, nil +} + +func 
buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) { + var query strings.Builder + _, _ = query.WriteString(` + WITH input ( + input_idx, + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + created_at + ) AS (VALUES `) + + args := make([]any, 0, len(keys)*38) + argPos := 1 + for idx, key := range keys { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + args = append(args, idx) + argPos++ + prepared := preparedByKey[key] + for i := 0; i < len(prepared.args); i++ { + _, _ = query.WriteString(",") + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + if i < len(usageLogInsertArgTypes) { + _, _ = query.WriteString("::") + _, _ = query.WriteString(usageLogInsertArgTypes[i]) + } + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) 
+ } + _, _ = query.WriteString(` + ), + inserted AS ( + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + created_at + ) + SELECT + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + created_at + FROM input + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING request_id, api_key_id, id, created_at + ), + resolved AS ( + SELECT + input.input_idx, + input.request_id, + input.api_key_id, + COALESCE(inserted.id, existing.id) AS id, + COALESCE(inserted.created_at, existing.created_at) AS created_at, + (inserted.id IS NOT NULL) AS inserted + FROM input + LEFT JOIN inserted + ON inserted.request_id = input.request_id + AND inserted.api_key_id = input.api_key_id + LEFT JOIN usage_logs existing + ON existing.request_id = input.request_id + AND existing.api_key_id = input.api_key_id + 
) + SELECT COALESCE( + json_agg( + json_build_object( + 'request_id', resolved.request_id, + 'api_key_id', resolved.api_key_id, + 'id', resolved.id, + 'created_at', resolved.created_at, + 'inserted', resolved.inserted + ) + ORDER BY resolved.input_idx + ), + '[]'::json + ) + FROM resolved + `) + return query.String(), args +} + +func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) { + var query strings.Builder + _, _ = query.WriteString(` + WITH input ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + created_at + ) AS (VALUES `) + + args := make([]any, 0, len(preparedList)*38) + argPos := 1 + for idx, prepared := range preparedList { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + for i := 0; i < len(prepared.args); i++ { + if i > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + if i < len(usageLogInsertArgTypes) { + _, _ = query.WriteString("::") + _, _ = query.WriteString(usageLogInsertArgTypes[i]) + } + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) 
+ } + + _, _ = query.WriteString(` + ) + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + created_at + ) + SELECT + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + created_at + FROM input + ON CONFLICT (request_id, api_key_id) DO NOTHING + `) + + return query.String(), args +} + +func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error { + _, err := sqlq.ExecContext(ctx, ` + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + 
rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, + $8, $9, $10, $11, + $12, $13, + $14, $15, $16, $17, $18, $19, + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 + ) + ON CONFLICT (request_id, api_key_id) DO NOTHING + `, prepared.args...) + return err +} + +func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { + createdAt := log.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now() + } + + requestID := strings.TrimSpace(log.RequestID) + log.RequestID = requestID + + rateMultiplier := log.RateMultiplier + log.SyncRequestTypeAndLegacyFields() + requestType := int16(log.RequestType) + groupID := nullInt64(log.GroupID) subscriptionID := nullInt64(log.SubscriptionID) duration := nullInt(log.DurationMs) @@ -158,64 +1117,84 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) mediaType := nullString(log.MediaType) + serviceTier := nullString(log.ServiceTier) reasoningEffort := nullString(log.ReasoningEffort) + inboundEndpoint := nullString(log.InboundEndpoint) + upstreamEndpoint := nullString(log.UpstreamEndpoint) var requestIDArg any if requestID != "" { requestIDArg = requestID } - args := []any{ - log.UserID, - log.APIKeyID, - log.AccountID, - requestIDArg, - log.Model, - groupID, - subscriptionID, - log.InputTokens, - log.OutputTokens, - log.CacheCreationTokens, - log.CacheReadTokens, - log.CacheCreation5mTokens, - log.CacheCreation1hTokens, - log.InputCost, - log.OutputCost, - log.CacheCreationCost, - log.CacheReadCost, - log.TotalCost, - log.ActualCost, - rateMultiplier, - 
log.AccountRateMultiplier, - log.BillingType, - requestType, - log.Stream, - log.OpenAIWSMode, - duration, - firstToken, - userAgent, - ipAddress, - log.ImageCount, - imageSize, - mediaType, - reasoningEffort, - log.CacheTTLOverridden, - createdAt, + return usageLogInsertPrepared{ + createdAt: createdAt, + requestID: requestID, + rateMultiplier: rateMultiplier, + requestType: requestType, + args: []any{ + log.UserID, + log.APIKeyID, + log.AccountID, + requestIDArg, + log.Model, + groupID, + subscriptionID, + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + rateMultiplier, + log.AccountRateMultiplier, + log.BillingType, + requestType, + log.Stream, + log.OpenAIWSMode, + duration, + firstToken, + userAgent, + ipAddress, + log.ImageCount, + imageSize, + mediaType, + serviceTier, + reasoningEffort, + inboundEndpoint, + upstreamEndpoint, + log.CacheTTLOverridden, + createdAt, + }, } - if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { - if errors.Is(err, sql.ErrNoRows) && requestID != "" { - selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" - if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil { - return false, err - } - log.RateMultiplier = rateMultiplier - return false, nil - } else { - return false, err - } +} + +func usageLogBatchKey(requestID string, apiKeyID int64) string { + return requestID + "\x1f" + strconv.FormatInt(apiKeyID, 10) +} + +func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateResult) { + if ch == nil { + return } - log.RateMultiplier = rateMultiplier - return true, nil + select { + case ch <- res: + default: + } +} + +func (r *usageLogRepository) 
bestEffortRecentKey(requestID string, apiKeyID int64) (string, bool) { + requestID = strings.TrimSpace(requestID) + if requestID == "" || r == nil || r.bestEffortRecent == nil { + return "", false + } + return usageLogBatchKey(requestID, apiKeyID), true } func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { @@ -1036,6 +2015,10 @@ type ModelStat = usagestats.ModelStat // UserUsageTrendPoint represents user usage trend data point type UserUsageTrendPoint = usagestats.UserUsageTrendPoint +// UserSpendingRankingItem represents a user spending ranking row. +type UserSpendingRankingItem = usagestats.UserSpendingRankingItem +type UserSpendingRankingResponse = usagestats.UserSpendingRankingResponse + // APIKeyUsageTrendPoint represents API key usage trend data point type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint @@ -1111,6 +2094,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e TO_CHAR(u.created_at, '%s') as date, u.user_id, COALESCE(us.email, '') as email, + COALESCE(us.username, '') as username, COUNT(*) as requests, COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens, COALESCE(SUM(u.total_cost), 0) as cost, @@ -1119,7 +2103,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e LEFT JOIN users us ON u.user_id = us.id WHERE u.user_id IN (SELECT user_id FROM top_users) AND u.created_at >= $4 AND u.created_at < $5 - GROUP BY date, u.user_id, us.email + GROUP BY date, u.user_id, us.email, us.username ORDER BY date ASC, tokens DESC `, dateFormat) @@ -1139,7 +2123,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e results = make([]UserUsageTrendPoint, 0) for rows.Next() { var row UserUsageTrendPoint - if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil { + if err = 
rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Username, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil { return nil, err } results = append(results, row) @@ -1151,6 +2135,86 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e return results, nil } +// GetUserSpendingRanking returns user spending ranking aggregated within the time range. +func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (result *UserSpendingRankingResponse, err error) { + if limit <= 0 { + limit = 12 + } + + query := ` + WITH user_spend AS ( + SELECT + u.user_id, + COALESCE(us.email, '') as email, + COALESCE(SUM(u.actual_cost), 0) as actual_cost, + COUNT(*) as requests, + COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens + FROM usage_logs u + LEFT JOIN users us ON u.user_id = us.id + WHERE u.created_at >= $1 AND u.created_at < $2 + GROUP BY u.user_id, us.email + ), + ranked AS ( + SELECT + user_id, + email, + actual_cost, + requests, + tokens, + COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost, + COALESCE(SUM(requests) OVER (), 0) as total_requests, + COALESCE(SUM(tokens) OVER (), 0) as total_tokens + FROM user_spend + ORDER BY actual_cost DESC, tokens DESC, user_id ASC + LIMIT $3 + ) + SELECT + user_id, + email, + actual_cost, + requests, + tokens, + total_actual_cost, + total_requests, + total_tokens + FROM ranked + ORDER BY actual_cost DESC, tokens DESC, user_id ASC + ` + + rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + result = nil + } + }() + + ranking := make([]UserSpendingRankingItem, 0) + totalActualCost := 0.0 + totalRequests := int64(0) + totalTokens := int64(0) + for rows.Next() { + var row UserSpendingRankingItem + if err = 
rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost, &totalRequests, &totalTokens); err != nil { + return nil, err + } + ranking = append(ranking, row) + } + if err = rows.Err(); err != nil { + return nil, err + } + + return &UserSpendingRankingResponse{ + Ranking: ranking, + TotalActualCost: totalActualCost, + TotalRequests: totalRequests, + TotalTokens: totalTokens, + }, nil +} + // UserDashboardStats 用户仪表盘统计 type UserDashboardStats = usagestats.UserDashboardStats @@ -1363,7 +2427,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(actual_cost), 0) as actual_cost @@ -1401,6 +2466,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64 COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(actual_cost), 0) as actual_cost @@ -1468,12 +2535,21 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat args = append(args, *filters.StartTime) } if filters.EndTime != nil { - conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1)) + conditions = append(conditions, fmt.Sprintf("created_at < 
$%d", len(args)+1)) args = append(args, *filters.EndTime) } whereClause := buildWhere(conditions) - logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params) + var ( + logs []service.UsageLog + page *pagination.PaginationResult + err error + ) + if shouldUseFastUsageLogTotal(filters) { + logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, params) + } else { + logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params) + } if err != nil { return nil, nil, err } @@ -1484,17 +2560,45 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat return logs, page, nil } +func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool { + if filters.ExactTotal { + return false + } + // 强选择过滤下记录集通常较小,保留精确总数。 + return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0 +} + // UsageStats represents usage statistics type UsageStats = usagestats.UsageStats // BatchUserUsageStats represents usage stats for a single user type BatchUserUsageStats = usagestats.BatchUserUsageStats +func normalizePositiveInt64IDs(ids []int64) []int64 { + if len(ids) == 0 { + return nil + } + seen := make(map[int64]struct{}, len(ids)) + out := make([]int64, 0, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + return out +} + // GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range. // If startTime is zero, defaults to 30 days ago. 
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) { result := make(map[int64]*BatchUserUsageStats) - if len(userIDs) == 0 { + normalizedUserIDs := normalizePositiveInt64IDs(userIDs) + if len(normalizedUserIDs) == 0 { return result, nil } @@ -1506,58 +2610,36 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs endTime = time.Now() } - for _, id := range userIDs { + for _, id := range normalizedUserIDs { result[id] = &BatchUserUsageStats{UserID: id} } query := ` - SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost + SELECT + user_id, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost FROM usage_logs - WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3 + WHERE user_id = ANY($1) + AND created_at >= LEAST($2, $4) GROUP BY user_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime) + today := timezone.Today() + rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today) if err != nil { return nil, err } for rows.Next() { var userID int64 var total float64 - if err := rows.Scan(&userID, &total); err != nil { + var todayTotal float64 + if err := rows.Scan(&userID, &total, &todayTotal); err != nil { _ = rows.Close() return nil, err } if stats, ok := result[userID]; ok { stats.TotalActualCost = total - } - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - - today := timezone.Today() - todayQuery := ` - SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost - FROM usage_logs - WHERE user_id = ANY($1) AND created_at >= $2 - GROUP BY user_id - ` - rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today) - if err != nil { - 
return nil, err - } - for rows.Next() { - var userID int64 - var total float64 - if err := rows.Scan(&userID, &total); err != nil { - _ = rows.Close() - return nil, err - } - if stats, ok := result[userID]; ok { - stats.TodayActualCost = total + stats.TodayActualCost = todayTotal } } if err := rows.Close(); err != nil { @@ -1577,7 +2659,8 @@ type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats // If startTime is zero, defaults to 30 days ago. func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) { result := make(map[int64]*BatchAPIKeyUsageStats) - if len(apiKeyIDs) == 0 { + normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs) + if len(normalizedAPIKeyIDs) == 0 { return result, nil } @@ -1589,58 +2672,36 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe endTime = time.Now() } - for _, id := range apiKeyIDs { + for _, id := range normalizedAPIKeyIDs { result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} } query := ` - SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost + SELECT + api_key_id, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost FROM usage_logs - WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3 + WHERE api_key_id = ANY($1) + AND created_at >= LEAST($2, $4) GROUP BY api_key_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime) + today := timezone.Today() + rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today) if err != nil { return nil, err } for rows.Next() { var apiKeyID int64 var total float64 - if err := rows.Scan(&apiKeyID, &total); err != nil { + var todayTotal float64 + if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil { _ = rows.Close() 
return nil, err } if stats, ok := result[apiKeyID]; ok { stats.TotalActualCost = total - } - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - - today := timezone.Today() - todayQuery := ` - SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost - FROM usage_logs - WHERE api_key_id = ANY($1) AND created_at >= $2 - GROUP BY api_key_id - ` - rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today) - if err != nil { - return nil, err - } - for rows.Next() { - var apiKeyID int64 - var total float64 - if err := rows.Scan(&apiKeyID, &total); err != nil { - _ = rows.Close() - return nil, err - } - if stats, ok := result[apiKeyID]; ok { - stats.TodayActualCost = total + stats.TodayActualCost = todayTotal } } if err := rows.Close(); err != nil { @@ -1655,6 +2716,13 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe // GetUsageTrendWithFilters returns usage trend data with optional filters func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { + if shouldUsePreaggregatedTrend(granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) { + aggregated, aggregatedErr := r.getUsageTrendFromAggregates(ctx, startTime, endTime, granularity) + if aggregatedErr == nil && len(aggregated) > 0 { + return aggregated, nil + } + } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` @@ -1663,7 +2731,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens, + 
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(actual_cost), 0) as actual_cost @@ -1719,6 +2788,80 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start return results, nil } +func shouldUsePreaggregatedTrend(granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) bool { + if granularity != "day" && granularity != "hour" { + return false + } + return userID == 0 && + apiKeyID == 0 && + accountID == 0 && + groupID == 0 && + model == "" && + requestType == nil && + stream == nil && + billingType == nil +} + +func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { + dateFormat := safeDateFormat(granularity) + query := "" + args := []any{startTime, endTime} + + switch granularity { + case "hour": + query = fmt.Sprintf(` + SELECT + TO_CHAR(bucket_start, '%s') as date, + total_requests as requests, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens, + total_cost as cost, + actual_cost + FROM usage_dashboard_hourly + WHERE bucket_start >= $1 AND bucket_start < $2 + ORDER BY bucket_start ASC + `, dateFormat) + case "day": + query = fmt.Sprintf(` + SELECT + TO_CHAR(bucket_date::timestamp, '%s') as date, + total_requests as requests, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens, + total_cost as cost, + actual_cost + FROM usage_dashboard_daily + WHERE bucket_date >= $1::date AND 
bucket_date < $2::date + ORDER BY bucket_date ASC + `, dateFormat) + default: + return nil, nil + } + + rows, err := r.sql.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results, err = scanTrendRows(rows) + if err != nil { + return nil, err + } + return results, nil +} + // GetModelStatsWithFilters returns model statistics with optional filters func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" @@ -1733,6 +2876,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, %s @@ -1867,7 +3012,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT COALESCE(SUM(actual_cost), 0) as total_actual_cost, COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs - WHERE created_at >= $1 AND created_at <= $2 + WHERE created_at >= $1 AND created_at < $2 ` stats := &UsageStats{} @@ -1925,7 +3070,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us args = append(args, *filters.StartTime) } if filters.EndTime != nil { - conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1)) + conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1)) args = append(args, 
*filters.EndTime) } @@ -1965,6 +3110,35 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us stats.TotalAccountCost = &totalAccountCost } stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens + + start := time.Unix(0, 0).UTC() + if filters.StartTime != nil { + start = *filters.StartTime + } + end := time.Now().UTC() + if filters.EndTime != nil { + end = *filters.EndTime + } + + endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType) + if endpointErr != nil { + logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetStatsWithFilters: %v", endpointErr) + endpoints = []EndpointStat{} + } + upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType) + if upstreamEndpointErr != nil { + logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetStatsWithFilters: %v", upstreamEndpointErr) + upstreamEndpoints = []EndpointStat{} + } + endpointPaths, endpointPathErr := r.getEndpointPathStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType) + if endpointPathErr != nil { + logger.LegacyPrintf("repository.usage_log", "getEndpointPathStatsWithFilters failed in GetStatsWithFilters: %v", endpointPathErr) + endpointPaths = []EndpointStat{} + } + stats.Endpoints = endpoints + stats.UpstreamEndpoints = upstreamEndpoints + stats.EndpointPaths = endpointPaths + return stats, nil } @@ -1977,6 +3151,163 @@ type AccountUsageSummary = usagestats.AccountUsageSummary // AccountUsageStatsResponse 
represents the full usage statistics response for an account type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse +// EndpointStat represents endpoint usage statistics row. +type EndpointStat = usagestats.EndpointStat + +func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) { + actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" + if accountID > 0 && userID == 0 && apiKeyID == 0 { + actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" + } + + query := fmt.Sprintf(` + SELECT + COALESCE(NULLIF(TRIM(%s), ''), 'unknown') AS endpoint, + COUNT(*) AS requests, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens, + COALESCE(SUM(total_cost), 0) as cost, + %s + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + `, endpointColumn, actualCostExpr) + + args := []any{startTime, endTime} + if userID > 0 { + query += fmt.Sprintf(" AND user_id = $%d", len(args)+1) + args = append(args, userID) + } + if apiKeyID > 0 { + query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1) + args = append(args, apiKeyID) + } + if accountID > 0 { + query += fmt.Sprintf(" AND account_id = $%d", len(args)+1) + args = append(args, accountID) + } + if groupID > 0 { + query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) + args = append(args, groupID) + } + if model != "" { + query += fmt.Sprintf(" AND model = $%d", len(args)+1) + args = append(args, model) + } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) + if billingType != nil { + query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) + args = append(args, int16(*billingType)) + } + query += " GROUP BY endpoint ORDER BY 
requests DESC" + + rows, err := r.sql.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results = make([]EndpointStat, 0) + for rows.Next() { + var row EndpointStat + if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil { + return nil, err + } + results = append(results, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} + +func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) { + actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" + if accountID > 0 && userID == 0 && apiKeyID == 0 { + actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" + } + + query := fmt.Sprintf(` + SELECT + CONCAT( + COALESCE(NULLIF(TRIM(inbound_endpoint), ''), 'unknown'), + ' -> ', + COALESCE(NULLIF(TRIM(upstream_endpoint), ''), 'unknown') + ) AS endpoint, + COUNT(*) AS requests, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens, + COALESCE(SUM(total_cost), 0) as cost, + %s + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + `, actualCostExpr) + + args := []any{startTime, endTime} + if userID > 0 { + query += fmt.Sprintf(" AND user_id = $%d", len(args)+1) + args = append(args, userID) + } + if apiKeyID > 0 { + query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1) + args = append(args, apiKeyID) + } + if accountID > 0 { + query += fmt.Sprintf(" AND account_id = $%d", len(args)+1) + args = append(args, accountID) + } + if groupID > 0 { + query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) + args = 
append(args, groupID) + } + if model != "" { + query += fmt.Sprintf(" AND model = $%d", len(args)+1) + args = append(args, model) + } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) + if billingType != nil { + query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) + args = append(args, int16(*billingType)) + } + query += " GROUP BY endpoint ORDER BY requests DESC" + + rows, err := r.sql.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results = make([]EndpointStat, 0) + for rows.Next() { + var row EndpointStat + if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil { + return nil, err + } + results = append(results, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} + +// GetEndpointStatsWithFilters returns inbound endpoint statistics with optional filters. +func (r *usageLogRepository) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) { + return r.getEndpointStatsByColumnWithFilters(ctx, "inbound_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) +} + +// GetUpstreamEndpointStatsWithFilters returns upstream endpoint statistics with optional filters. 
+func (r *usageLogRepository) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) { + return r.getEndpointStatsByColumnWithFilters(ctx, "upstream_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) +} + // GetAccountUsageStats returns comprehensive usage statistics for an account over a time range func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) { daysCount := int(endTime.Sub(startTime).Hours()/24) + 1 @@ -2139,11 +3470,23 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID if err != nil { models = []ModelStat{} } + endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil) + if endpointErr != nil { + logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetAccountUsageStats: %v", endpointErr) + endpoints = []EndpointStat{} + } + upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil) + if upstreamEndpointErr != nil { + logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetAccountUsageStats: %v", upstreamEndpointErr) + upstreamEndpoints = []EndpointStat{} + } resp = &AccountUsageStatsResponse{ - History: history, - Summary: summary, - Models: models, + History: history, + Summary: summary, + Models: models, + Endpoints: endpoints, + UpstreamEndpoints: upstreamEndpoints, } return resp, nil } @@ -2166,6 +3509,35 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh return logs, paginationResultFromTotal(total, params), nil } +func (r *usageLogRepository) 
listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + limit := params.Limit() + offset := params.Offset() + + limitPos := len(args) + 1 + offsetPos := len(args) + 2 + listArgs := append(append([]any{}, args...), limit+1, offset) + query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) + + logs, err := r.queryUsageLogs(ctx, query, listArgs...) + if err != nil { + return nil, nil, err + } + + hasMore := false + if len(logs) > limit { + hasMore = true + logs = logs[:limit] + } + + total := int64(offset) + int64(len(logs)) + if hasMore { + // 只保证“还有下一页”,避免对超大表做全量 COUNT(*)。 + total = int64(offset) + int64(limit) + 1 + } + + return logs, paginationResultFromTotal(total, params), nil +} + func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) { rows, err := r.sql.QueryContext(ctx, query, args...) 
if err != nil { @@ -2395,7 +3767,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e imageCount int imageSize sql.NullString mediaType sql.NullString + serviceTier sql.NullString reasoningEffort sql.NullString + inboundEndpoint sql.NullString + upstreamEndpoint sql.NullString cacheTTLOverridden bool createdAt time.Time ) @@ -2434,7 +3809,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &imageCount, &imageSize, &mediaType, + &serviceTier, &reasoningEffort, + &inboundEndpoint, + &upstreamEndpoint, &cacheTTLOverridden, &createdAt, ); err != nil { @@ -2504,9 +3882,18 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if mediaType.Valid { log.MediaType = &mediaType.String } + if serviceTier.Valid { + log.ServiceTier = &serviceTier.String + } if reasoningEffort.Valid { log.ReasoningEffort = &reasoningEffort.String } + if inboundEndpoint.Valid { + log.InboundEndpoint = &inboundEndpoint.String + } + if upstreamEndpoint.Valid { + log.UpstreamEndpoint = &upstreamEndpoint.String + } return log, nil } @@ -2520,7 +3907,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) { &row.Requests, &row.InputTokens, &row.OutputTokens, - &row.CacheTokens, + &row.CacheCreationTokens, + &row.CacheReadTokens, &row.TotalTokens, &row.Cost, &row.ActualCost, @@ -2544,6 +3932,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) { &row.Requests, &row.InputTokens, &row.OutputTokens, + &row.CacheCreationTokens, + &row.CacheReadTokens, &row.TotalTokens, &row.Cost, &row.ActualCost, diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 4d50f7de..0383f3bc 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -4,6 +4,8 @@ package repository import ( "context" + "fmt" + "sync" "testing" "time" 
@@ -14,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() { s.Require().NotZero(log.ID) } +func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()}) + + const total = 16 + results := make([]bool, total) + errs := make([]error, total) + logs := make([]*service.UsageLog, total) + + var wg sync.WaitGroup + wg.Add(total) + for i := 0; i < total; i++ { + i := i + logs[i] = &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10 + i, + OutputTokens: 20 + i, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + go func() { + defer wg.Done() + results[i], errs[i] = repo.Create(ctx, logs[i]) + }() + } + wg.Wait() + + for i := 0; i < total; i++ { + require.NoError(t, errs[i]) + require.True(t, results[i]) + require.NotZero(t, logs[i].ID) + } + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count)) + require.Equal(t, total, count) +} + +func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, 
integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()}) + requestID := uuid.NewString() + + log1 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + log2 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + + inserted1, err1 := repo.Create(ctx, log1) + inserted2, err2 := repo.Create(ctx, log2) + require.NoError(t, err1) + require.NoError(t, err2) + require.True(t, inserted1) + require.False(t, inserted2) + require.Equal(t, log1.ID, log2.ID) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)) + require.Equal(t, 1, count) +} + +func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()}) + requestID := uuid.NewString() + + const 
total = 8 + batch := make([]usageLogCreateRequest, 0, total) + logs := make([]*service.UsageLog, 0, total) + + for i := 0; i < total; i++ { + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10 + i, + OutputTokens: 20 + i, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + logs = append(logs, log) + batch = append(batch, usageLogCreateRequest{ + log: log, + prepared: prepareUsageLogInsert(log), + resultCh: make(chan usageLogCreateResult, 1), + }) + } + + repo.flushCreateBatch(integrationDB, batch) + + insertedCount := 0 + var firstID int64 + for idx, req := range batch { + res := <-req.resultCh + require.NoError(t, res.err) + if res.inserted { + insertedCount++ + } + require.NotZero(t, logs[idx].ID) + if idx == 0 { + firstID = logs[idx].ID + } else { + require.Equal(t, firstID, logs[idx].ID) + } + } + + require.Equal(t, 1, insertedCount) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)) + require.Equal(t, 1, count) +} + +func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()}) + requestID := uuid.NewString() + + log1 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + 
OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + log2 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + + require.NoError(t, repo.CreateBestEffort(ctx, log1)) + require.NoError(t, repo.CreateBestEffort(ctx, log2)) + + require.Eventually(t, func() bool { + var count int + err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count) + return err == nil && count == 1 + }, 3*time.Second, 20*time.Millisecond) +} + +func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1) + repo.bestEffortBatchCh <- usageLogBestEffortRequest{} + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()}) + + err := repo.CreateBestEffort(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + + require.Error(t, err) + require.True(t, service.IsUsageLogCreateDropped(err)) +} + +func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := 
newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + inserted, err := repo.Create(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + + require.False(t, inserted) + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) +} + +func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.createBatchCh = make(chan usageLogCreateRequest, 1) + repo.createBatchCh <- usageLogCreateRequest{} + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()}) + + inserted, err := repo.Create(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + + require.False(t, inserted) + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) +} + 
+func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.createBatchCh = make(chan usageLogCreateRequest, 1) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()}) + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + + go func() { + _, err := repo.createBatched(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + errCh <- err + }() + + req := <-repo.createBatchCh + require.NotNil(t, req.shared) + cancel() + + err := <-errCh + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)}) +} + +func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: 
apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + req := usageLogCreateRequest{ + log: log, + prepared: prepareUsageLogInsert(log), + shared: &usageLogCreateShared{}, + resultCh: make(chan usageLogCreateResult, 1), + } + req.shared.state.Store(usageLogCreateStateCanceled) + + repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req}) + + res := <-req.resultCh + require.False(t, res.inserted) + require.Error(t, res.err) + require.True(t, service.IsUsageLogCreateNotPersisted(res.err)) +} + func (s *UsageLogRepoSuite) TestGetByID() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 95cf2a2d..27ae4571 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -71,7 +71,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { log.ImageCount, sqlmock.AnyArg(), // image_size sqlmock.AnyArg(), // media_type + sqlmock.AnyArg(), // service_tier sqlmock.AnyArg(), // reasoning_effort + sqlmock.AnyArg(), // inbound_endpoint + sqlmock.AnyArg(), // upstream_endpoint log.CacheTTLOverridden, createdAt, ). 
@@ -81,12 +84,78 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { require.NoError(t, err) require.True(t, inserted) require.Equal(t, int64(99), log.ID) + require.Nil(t, log.ServiceTier) require.Equal(t, service.RequestTypeWSV2, log.RequestType) require.True(t, log.Stream) require.True(t, log.OpenAIWSMode) require.NoError(t, mock.ExpectationsWereMet()) } +func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC) + serviceTier := "priority" + log := &service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-service-tier", + Model: "gpt-5.4", + ServiceTier: &serviceTier, + CreatedAt: createdAt, + } + + mock.ExpectQuery("INSERT INTO usage_logs"). + WithArgs( + log.UserID, + log.APIKeyID, + log.AccountID, + log.RequestID, + log.Model, + sqlmock.AnyArg(), + sqlmock.AnyArg(), + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + log.RateMultiplier, + log.AccountRateMultiplier, + log.BillingType, + int16(service.RequestTypeSync), + false, + false, + sqlmock.AnyArg(), + sqlmock.AnyArg(), + sqlmock.AnyArg(), + sqlmock.AnyArg(), + log.ImageCount, + sqlmock.AnyArg(), + sqlmock.AnyArg(), + serviceTier, + sqlmock.AnyArg(), + sqlmock.AnyArg(), + sqlmock.AnyArg(), + log.CacheTTLOverridden, + createdAt, + ). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) + + inserted, err := repo.Create(context.Background(), log) + require.NoError(t, err) + require.True(t, inserted) + require.NoError(t, mock.ExpectationsWereMet()) +} + func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { db, mock := newSQLMock(t) repo := &usageLogRepository{sql: db} @@ -96,6 +165,7 @@ func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { filters := usagestats.UsageLogFilters{ RequestType: &requestType, Stream: &stream, + ExactTotal: true, } mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). @@ -124,7 +194,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)"). WithArgs(start, end, requestType). - WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"})) + WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"})) trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil) require.NoError(t, err) @@ -143,7 +213,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). WithArgs(start, end, requestType). 
- WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"})) + WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"})) stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil) require.NoError(t, err) @@ -182,6 +252,37 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) require.NoError(t, mock.ExpectationsWereMet()) } +func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + + rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost", "total_requests", "total_tokens"}). + AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0, int64(30), int64(2600)). + AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0, int64(30), int64(2600)). + AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0, int64(30), int64(2600)) + + mock.ExpectQuery("WITH user_spend AS \\("). + WithArgs(start, end, 12). 
+ WillReturnRows(rows) + + got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12) + require.NoError(t, err) + require.Equal(t, &usagestats.UserSpendingRankingResponse{ + Ranking: []usagestats.UserSpendingRankingItem{ + {UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900}, + {UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800}, + {UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300}, + }, + TotalActualCost: 40.0, + TotalRequests: 30, + TotalTokens: 2600, + }, got) + require.NoError(t, mock.ExpectationsWereMet()) +} + func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) { tests := []struct { name string @@ -279,11 +380,16 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { 0, sql.NullString{}, sql.NullString{}, + sql.NullString{Valid: true, String: "priority"}, + sql.NullString{}, + sql.NullString{}, sql.NullString{}, false, now, }}) require.NoError(t, err) + require.NotNil(t, log.ServiceTier) + require.Equal(t, "priority", *log.ServiceTier) require.Equal(t, service.RequestTypeWSV2, log.RequestType) require.True(t, log.Stream) require.True(t, log.OpenAIWSMode) @@ -315,13 +421,57 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { 0, sql.NullString{}, sql.NullString{}, + sql.NullString{Valid: true, String: "flex"}, + sql.NullString{}, + sql.NullString{}, sql.NullString{}, false, now, }}) require.NoError(t, err) + require.NotNil(t, log.ServiceTier) + require.Equal(t, "flex", *log.ServiceTier) require.Equal(t, service.RequestTypeStream, log.RequestType) require.True(t, log.Stream) require.False(t, log.OpenAIWSMode) }) + + t.Run("service_tier_is_scanned", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(3), + int64(12), + int64(22), + int64(32), + sql.NullString{Valid: true, String: "req-3"}, + "gpt-5.4", + sql.NullInt64{}, + sql.NullInt64{}, 
+ 1, 2, 3, 4, 5, 6, + 0.1, 0.2, 0.3, 0.4, 1.0, 0.9, + 1.0, + sql.NullFloat64{}, + int16(service.BillingTypeBalance), + int16(service.RequestTypeSync), + false, + false, + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 0, + sql.NullString{}, + sql.NullString{}, + sql.NullString{Valid: true, String: "priority"}, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + false, + now, + }}) + require.NoError(t, err) + require.NotNil(t, log.ServiceTier) + require.Equal(t, "priority", *log.ServiceTier) + }) + } diff --git a/backend/internal/repository/usage_log_repo_unit_test.go b/backend/internal/repository/usage_log_repo_unit_test.go index d0e14ffd..0458902d 100644 --- a/backend/internal/repository/usage_log_repo_unit_test.go +++ b/backend/internal/repository/usage_log_repo_unit_test.go @@ -3,8 +3,11 @@ package repository import ( + "strings" "testing" + "time" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" ) @@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) { }) } } + +func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) { + log := &service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-batch-no-update", + Model: "gpt-5", + InputTokens: 10, + OutputTokens: 5, + TotalCost: 1.2, + ActualCost: 1.2, + CreatedAt: time.Now().UTC(), + } + prepared := prepareUsageLogInsert(log) + + query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{ + usageLogBatchKey(log.RequestID, log.APIKeyID): prepared, + }) + + require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING") + require.NotContains(t, strings.ToUpper(query), "DO UPDATE") +} diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go index e3b11096..e2471ae5 100644 --- a/backend/internal/repository/user_group_rate_repo.go +++ 
b/backend/internal/repository/user_group_rate_repo.go @@ -95,6 +95,35 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in return result, nil } +// GetByGroupID 获取指定分组下所有用户的专属倍率 +func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) { + query := ` + SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier + FROM user_group_rate_multipliers ugr + JOIN users u ON u.id = ugr.user_id + WHERE ugr.group_id = $1 + ORDER BY ugr.user_id + ` + rows, err := r.sql.QueryContext(ctx, query, groupID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var result []service.UserGroupRateEntry + for rows.Next() { + var entry service.UserGroupRateEntry + if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil { + return nil, err + } + result = append(result, entry) + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + // GetByUserAndGroup 获取用户在特定分组的专属倍率 func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` @@ -164,6 +193,31 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID return nil } +// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插) +func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error { + if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil { + return err + } + if len(entries) == 0 { + return nil + } + userIDs := make([]int64, len(entries)) + rates := make([]float64, len(entries)) + for i, e := range entries { + userIDs[i] = e.UserID + rates[i] = 
e.RateMultiplier + } + now := time.Now() + _, err := r.sql.ExecContext(ctx, ` + INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) + SELECT data.user_id, $1::bigint, data.rate_multiplier, $2::timestamptz, $2::timestamptz + FROM unnest($3::bigint[], $4::double precision[]) AS data(user_id, rate_multiplier) + ON CONFLICT (user_id, group_id) + DO UPDATE SET rate_multiplier = EXCLUDED.rate_multiplier, updated_at = EXCLUDED.updated_at + `, groupID, now, pq.Array(userIDs), pq.Array(rates)) + return err +} + // DeleteByGroupID 删除指定分组的所有用户专属倍率 func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error { _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID) diff --git a/backend/internal/repository/user_msg_queue_cache.go b/backend/internal/repository/user_msg_queue_cache.go new file mode 100644 index 00000000..bb3ee698 --- /dev/null +++ b/backend/internal/repository/user_msg_queue_cache.go @@ -0,0 +1,186 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +// Redis Key 模式(使用 hash tag 确保 Redis Cluster 下同一 accountID 的 key 落入同一 slot) +// 格式: umq:{accountID}:lock / umq:{accountID}:last +const ( + umqKeyPrefix = "umq:" + umqLockSuffix = ":lock" // STRING (requestID), PX lockTtlMs + umqLastSuffix = ":last" // STRING (毫秒时间戳), EX 60s +) + +// Lua 脚本:原子获取串行锁(SET NX PX + 重入安全) +var acquireLockScript = redis.NewScript(` +local cur = redis.call('GET', KEYS[1]) +if cur == ARGV[1] then + redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[2])) + return 1 +end +if cur ~= false then return 0 end +redis.call('SET', KEYS[1], ARGV[1], 'PX', tonumber(ARGV[2])) +return 1 +`) + +// Lua 脚本:原子释放锁 + 记录完成时间(使用 Redis TIME 避免时钟偏差) +var releaseLockScript = redis.NewScript(` +local cur = redis.call('GET', KEYS[1]) +if cur == ARGV[1] then + 
redis.call('DEL', KEYS[1]) + local t = redis.call('TIME') + local ms = tonumber(t[1])*1000 + math.floor(tonumber(t[2])/1000) + redis.call('SET', KEYS[2], ms, 'EX', 60) + return 1 +end +return 0 +`) + +// Lua 脚本:原子清理孤儿锁(仅在 PTTL == -1 时删除,避免 TOCTOU 竞态误删合法锁) +var forceReleaseLockScript = redis.NewScript(` +local pttl = redis.call('PTTL', KEYS[1]) +if pttl == -1 then + redis.call('DEL', KEYS[1]) + return 1 +end +return 0 +`) + +type userMsgQueueCache struct { + rdb *redis.Client +} + +// NewUserMsgQueueCache 创建用户消息队列缓存 +func NewUserMsgQueueCache(rdb *redis.Client) service.UserMsgQueueCache { + return &userMsgQueueCache{rdb: rdb} +} + +func umqLockKey(accountID int64) string { + // 格式: umq:{123}:lock — 花括号确保 Redis Cluster hash tag 生效 + return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLockSuffix +} + +func umqLastKey(accountID int64) string { + // 格式: umq:{123}:last — 与 lockKey 同一 hash slot + return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLastSuffix +} + +// umqScanPattern 用于 SCAN 扫描锁 key +func umqScanPattern() string { + return umqKeyPrefix + "{*}" + umqLockSuffix +} + +// AcquireLock 尝试获取账号级串行锁 +func (c *userMsgQueueCache) AcquireLock(ctx context.Context, accountID int64, requestID string, lockTtlMs int) (bool, error) { + key := umqLockKey(accountID) + result, err := acquireLockScript.Run(ctx, c.rdb, []string{key}, requestID, lockTtlMs).Int() + if err != nil { + return false, fmt.Errorf("umq acquire lock: %w", err) + } + return result == 1, nil +} + +// ReleaseLock 释放锁并记录完成时间 +func (c *userMsgQueueCache) ReleaseLock(ctx context.Context, accountID int64, requestID string) (bool, error) { + lockKey := umqLockKey(accountID) + lastKey := umqLastKey(accountID) + result, err := releaseLockScript.Run(ctx, c.rdb, []string{lockKey, lastKey}, requestID).Int() + if err != nil { + return false, fmt.Errorf("umq release lock: %w", err) + } + return result == 1, nil +} + +// GetLastCompletedMs 获取上次完成时间(毫秒时间戳) +func (c 
*userMsgQueueCache) GetLastCompletedMs(ctx context.Context, accountID int64) (int64, error) { + key := umqLastKey(accountID) + val, err := c.rdb.Get(ctx, key).Result() + if errors.Is(err, redis.Nil) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("umq get last completed: %w", err) + } + ms, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return 0, fmt.Errorf("umq parse last completed: %w", err) + } + return ms, nil +} + +// ForceReleaseLock 原子清理孤儿锁(仅在 PTTL == -1 时删除,防止 TOCTOU 竞态误删合法锁) +func (c *userMsgQueueCache) ForceReleaseLock(ctx context.Context, accountID int64) error { + key := umqLockKey(accountID) + _, err := forceReleaseLockScript.Run(ctx, c.rdb, []string{key}).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return fmt.Errorf("umq force release lock: %w", err) + } + return nil +} + +// ScanLockKeys 扫描所有锁 key,仅返回 PTTL == -1(无过期时间)的孤儿锁 accountID 列表 +// 正常的锁都有 PX 过期时间,PTTL == -1 表示异常状态(如 Redis 故障恢复后丢失 TTL) +func (c *userMsgQueueCache) ScanLockKeys(ctx context.Context, maxCount int) ([]int64, error) { + var accountIDs []int64 + var cursor uint64 + pattern := umqScanPattern() + + for { + keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, 100).Result() + if err != nil { + return nil, fmt.Errorf("umq scan lock keys: %w", err) + } + for _, key := range keys { + // 检查 PTTL:只清理 PTTL == -1(无过期时间)的异常锁 + pttl, err := c.rdb.PTTL(ctx, key).Result() + if err != nil { + continue + } + // PTTL 返回值:-2 = key 不存在,-1 = 无过期时间,>0 = 剩余毫秒 + // go-redis 对哨兵值 -1/-2 不乘精度系数,直接返回 time.Duration(-1)/-2 + // 只删除 -1(无过期时间的异常锁),跳过正常持有的锁 + if pttl != time.Duration(-1) { + continue + } + + // 从 key 中提取 accountID: umq:{123}:lock → 提取 {} 内的数字 + openBrace := strings.IndexByte(key, '{') + closeBrace := strings.IndexByte(key, '}') + if openBrace < 0 || closeBrace <= openBrace+1 { + continue + } + idStr := key[openBrace+1 : closeBrace] + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + continue + } + accountIDs = append(accountIDs, id) + if 
len(accountIDs) >= maxCount { + return accountIDs, nil + } + } + cursor = nextCursor + if cursor == 0 { + break + } + } + return accountIDs, nil +} + +// GetCurrentTimeMs 通过 Redis TIME 命令获取当前服务器时间(毫秒),确保与锁记录的时间源一致 +func (c *userMsgQueueCache) GetCurrentTimeMs(ctx context.Context) (int64, error) { + t, err := c.rdb.Time(ctx).Result() + if err != nil { + return 0, fmt.Errorf("umq get redis time: %w", err) + } + return t.UnixMilli(), nil +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 05b68968..b56aaaf9 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -243,21 +243,24 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. userMap[u.ID] = &outUsers[len(outUsers)-1] } - // Batch load active subscriptions with groups to avoid N+1. - subs, err := r.client.UserSubscription.Query(). - Where( - usersubscription.UserIDIn(userIDs...), - usersubscription.StatusEQ(service.SubscriptionStatusActive), - ). - WithGroup(). - All(ctx) - if err != nil { - return nil, nil, err - } + shouldLoadSubscriptions := filters.IncludeSubscriptions == nil || *filters.IncludeSubscriptions + if shouldLoadSubscriptions { + // Batch load active subscriptions with groups to avoid N+1. + subs, err := r.client.UserSubscription.Query(). + Where( + usersubscription.UserIDIn(userIDs...), + usersubscription.StatusEQ(service.SubscriptionStatusActive), + ). + WithGroup(). 
+ All(ctx) + if err != nil { + return nil, nil, err + } - for i := range subs { - if u, ok := userMap[subs[i].UserID]; ok { - u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i])) + for i := range subs { + if u, ok := userMap[subs[i].UserID]; ok { + u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i])) + } } } diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 2344035c..138bf59e 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -34,7 +34,7 @@ func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient // ProvidePricingRemoteClient 创建定价数据远程客户端 // 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据 func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient { - return NewPricingRemoteClient(cfg.Update.ProxyURL) + return NewPricingRemoteClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError) } // ProvideSessionLimitCache 创建会话限制缓存 @@ -53,13 +53,16 @@ var ProviderSet = wire.NewSet( NewAPIKeyRepository, NewGroupRepository, NewAccountRepository, - NewSoraAccountRepository, // Sora 账号扩展表仓储 + NewSoraAccountRepository, // Sora 账号扩展表仓储 + NewScheduledTestPlanRepository, // 定时测试计划仓储 + NewScheduledTestResultRepository, // 定时测试结果仓储 NewProxyRepository, NewRedeemCodeRepository, NewPromoCodeRepository, NewAnnouncementRepository, NewAnnouncementReadRepository, NewUsageLogRepository, + NewUsageBillingRepository, NewIdempotencyRepository, NewUsageCleanupRepository, NewDashboardAggregationRepository, @@ -80,6 +83,7 @@ var ProviderSet = wire.NewSet( ProvideConcurrencyCache, ProvideSessionLimitCache, NewRPMCache, + NewUserMsgQueueCache, NewDashboardCache, NewEmailCache, NewIdentityCache, @@ -96,6 +100,10 @@ var ProviderSet = wire.NewSet( // Encryptors NewAESEncryptor, + // Backup infrastructure + NewPgDumper, + NewS3BackupStoreFactory, + // HTTP service ports (DI Strategy A: return interface 
directly) NewTurnstileVerifier, ProvidePricingRemoteClient, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 2738ed18..6056d51f 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -86,6 +86,15 @@ func TestAPIContracts(t *testing.T) { "last_used_at": null, "quota": 0, "quota_used": 0, + "rate_limit_5h": 0, + "rate_limit_1d": 0, + "rate_limit_7d": 0, + "usage_5h": 0, + "usage_1d": 0, + "usage_7d": 0, + "window_5h_start": null, + "window_1d_start": null, + "window_7d_start": null, "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" @@ -126,6 +135,15 @@ func TestAPIContracts(t *testing.T) { "last_used_at": null, "quota": 0, "quota_used": 0, + "rate_limit_5h": 0, + "rate_limit_1d": 0, + "rate_limit_7d": 0, + "usage_5h": 0, + "usage_1d": 0, + "usage_7d": 0, + "window_5h_start": null, + "window_1d_start": null, + "window_7d_start": null, "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" @@ -192,8 +210,9 @@ func TestAPIContracts(t *testing.T) { "sora_video_price_per_request": null, "sora_video_price_per_request_hd": null, "claude_code_only": false, + "allow_messages_dispatch": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -428,9 +448,10 @@ func TestAPIContracts(t *testing.T) { setup: func(t *testing.T, deps *contractDeps) { t.Helper() deps.settingRepo.SetAll(map[string]string{ - service.SettingKeyRegistrationEnabled: "true", - service.SettingKeyEmailVerifyEnabled: "false", - service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyEmailVerifyEnabled: "false", + service.SettingKeyRegistrationEmailSuffixWhitelist: "[]", + service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeySMTPHost: "smtp.example.com", service.SettingKeySMTPPort: "587", @@ -469,8 +490,10 @@ func TestAPIContracts(t *testing.T) { "data": { "registration_enabled": true, "email_verify_enabled": false, + "registration_email_suffix_whitelist": [], "promo_code_enabled": true, "password_reset_enabled": false, + "frontend_url": "", "totp_enabled": false, "totp_encryption_key_configured": false, "smtp_host": "smtp.example.com", @@ -513,7 +536,10 @@ func TestAPIContracts(t *testing.T) { "hide_ccs_import_button": false, "purchase_subscription_enabled": false, "purchase_subscription_url": "", - "min_claude_code_version": "" + "min_claude_code_version": "", + "allow_ungrouped_key_scheduling": false, + "backend_mode_enabled": false, + "custom_menu_items": [] } }`, }, @@ -621,7 +647,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) @@ -1026,6 +1052,14 @@ func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte return nil, errors.New("not implemented") } +func (s *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + return nil, errors.New("not 
implemented") +} + func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return errors.New("not implemented") } @@ -1066,6 +1100,14 @@ func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map return errors.New("not implemented") } +func (s *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { s.bulkUpdateIDs = append([]int64{}, ids...) return int64(len(ids)), nil @@ -1383,7 +1425,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { return nil } -func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { +func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { ids := make([]int64, 0, len(r.byID)) for id := range r.byID { if r.byID[id].UserID == userID { @@ -1497,6 +1539,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti return nil } +func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + return nil +} +func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error { + return nil +} +func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) { + return nil, nil +} + type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } @@ -1573,6 +1625,14 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi return nil, 
errors.New("not implemented") } +func (r *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) { + return nil, errors.New("not implemented") +} + func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { return nil, errors.New("not implemented") } @@ -1585,6 +1645,10 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end return nil, errors.New("not implemented") } +func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) { + return nil, errors.New("not implemented") +} + func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { logs := r.userLogs[userID] if len(logs) == 0 { diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 033a5b77..138663c4 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} - authService := service.NewAuthService(nil, nil, nil, 
cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) admin := &service.User{ ID: 1, diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 19f97239..972c1eaf 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -19,8 +19,16 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS } // apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证) +// +// 中间件职责分为两层: +// - 鉴权(Authentication):验证 Key 有效性、用户状态、IP 限制 —— 始终执行 +// - 计费执行(Billing Enforcement):过期/配额/订阅/余额检查 —— skipBilling 时整块跳过 +// +// /v1/usage 端点只需鉴权,不需要计费执行(允许过期/配额耗尽的 Key 查询自身用量)。 func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { + // ── 1. 提取 API Key ────────────────────────────────────────── + queryKey := strings.TrimSpace(c.Query("key")) queryApiKey := strings.TrimSpace(c.Query("api_key")) if queryKey != "" || queryApiKey != "" { @@ -56,7 +64,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } - // 从数据库验证API key + // ── 2. 
验证 Key 存在 ───────────────────────────────────────── + apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) if err != nil { if errors.Is(err, service.ErrAPIKeyNotFound) { @@ -67,29 +76,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } - // 检查API key是否激活 - if !apiKey.IsActive() { - // Provide more specific error message based on status - switch apiKey.Status { - case service.StatusAPIKeyQuotaExhausted: - AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") - case service.StatusAPIKeyExpired: - AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") - default: - AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") - } - return - } + // ── 3. 基础鉴权(始终执行) ───────────────────────────────── - // 检查API Key是否过期(即使状态是active,也要检查时间) - if apiKey.IsExpired() { - AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") - return - } - - // 检查API Key配额是否耗尽 - if apiKey.IsQuotaExhausted() { - AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") + // disabled / 未知状态 → 无条件拦截(expired 和 quota_exhausted 留给计费阶段) + if !apiKey.IsActive() && + apiKey.Status != service.StatusAPIKeyExpired && + apiKey.Status != service.StatusAPIKeyQuotaExhausted { + AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") return } @@ -116,8 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } + // ── 4. SimpleMode → early return ───────────────────────────── + if cfg.RunMode == config.RunModeSimple { - // 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文 c.Set(string(ContextKeyAPIKey), apiKey) c.Set(string(ContextKeyUser), AuthSubject{ UserID: apiKey.User.ID, @@ -130,54 +124,89 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } - // 判断计费方式:订阅模式 vs 余额模式 + // ── 5. 
加载订阅(订阅模式时始终加载) ─────────────────────── + + // skipBilling: /v1/usage 只需鉴权,跳过所有计费执行 + skipBilling := c.Request.URL.Path == "/v1/usage" + + var subscription *service.UserSubscription isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() if isSubscriptionType && subscriptionService != nil { - // 订阅模式:获取订阅(L1 缓存 + singleflight) - subscription, err := subscriptionService.GetActiveSubscription( + sub, subErr := subscriptionService.GetActiveSubscription( c.Request.Context(), apiKey.User.ID, apiKey.Group.ID, ) - if err != nil { - AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group") - return - } - - // 合并验证 + 限额检查(纯内存操作) - needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) - if err != nil { - code := "SUBSCRIPTION_INVALID" - status := 403 - if errors.Is(err, service.ErrDailyLimitExceeded) || - errors.Is(err, service.ErrWeeklyLimitExceeded) || - errors.Is(err, service.ErrMonthlyLimitExceeded) { - code = "USAGE_LIMIT_EXCEEDED" - status = 429 + if subErr != nil { + if !skipBilling { + AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group") + return } - AbortWithError(c, status, code, err.Error()) - return - } - - // 将订阅信息存入上下文 - c.Set(string(ContextKeySubscription), subscription) - - // 窗口维护异步化(不阻塞请求) - // 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race - if needsMaintenance { - maintenanceCopy := *subscription - subscriptionService.DoWindowMaintenance(&maintenanceCopy) - } - } else { - // 余额模式:检查用户余额 - if apiKey.User.Balance <= 0 { - AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance") - return + // skipBilling: 订阅不存在也放行,handler 会返回可用的数据 + } else { + subscription = sub } } - // 将API key和用户信息存入上下文 + // ── 6. 
计费执行(skipBilling 时整块跳过) ──────────────────── + + if !skipBilling { + // Key 状态检查 + switch apiKey.Status { + case service.StatusAPIKeyQuotaExhausted: + AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") + return + case service.StatusAPIKeyExpired: + AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") + return + } + + // 运行时过期/配额检查(即使状态是 active,也要检查时间和用量) + if apiKey.IsExpired() { + AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") + return + } + if apiKey.IsQuotaExhausted() { + AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") + return + } + + // 订阅模式:验证订阅限额 + if subscription != nil { + needsMaintenance, validateErr := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) + if validateErr != nil { + code := "SUBSCRIPTION_INVALID" + status := 403 + if errors.Is(validateErr, service.ErrDailyLimitExceeded) || + errors.Is(validateErr, service.ErrWeeklyLimitExceeded) || + errors.Is(validateErr, service.ErrMonthlyLimitExceeded) { + code = "USAGE_LIMIT_EXCEEDED" + status = 429 + } + AbortWithError(c, status, code, validateErr.Error()) + return + } + + // 窗口维护异步化(不阻塞请求) + if needsMaintenance { + maintenanceCopy := *subscription + subscriptionService.DoWindowMaintenance(&maintenanceCopy) + } + } else { + // 非订阅模式 或 订阅模式但 subscriptionService 未注入:回退到余额检查 + if apiKey.User.Balance <= 0 { + AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance") + return + } + } + } + + // ── 7. 
设置上下文 → Next ───────────────────────────────────── + + if subscription != nil { + c.Set(string(ContextKeySubscription), subscription) + } c.Set(string(ContextKeyAPIKey), apiKey) c.Set(string(ContextKeyUser), AuthSubject{ UserID: apiKey.User.ID, diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 2124c86c..49db5f19 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -56,7 +56,7 @@ func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error { func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } -func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { +func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { @@ -95,6 +95,15 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim } return nil } +func (f fakeAPIKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + return nil +} +func (f fakeAPIKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error { + return nil +} +func (f fakeAPIKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) { + return &service.APIKeyRateLimitData{}, nil +} func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { return errors.New("not implemented") diff --git a/backend/internal/server/middleware/api_key_auth_test.go 
b/backend/internal/server/middleware/api_key_auth_test.go index 0d331761..22befa2a 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -537,7 +537,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } -func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { +func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } @@ -588,6 +588,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti return nil } +func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + return nil +} +func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error { + return nil +} +func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) { + return nil, nil +} + type stubUserSubscriptionRepo struct { getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status string) error diff --git a/backend/internal/server/middleware/backend_mode_guard.go b/backend/internal/server/middleware/backend_mode_guard.go new file mode 100644 index 00000000..46482af3 --- /dev/null +++ b/backend/internal/server/middleware/backend_mode_guard.go @@ -0,0 +1,51 @@ +package middleware + +import ( + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is 
enabled. +// Must be placed AFTER JWT auth middleware so that the user role is available in context. +func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc { + return func(c *gin.Context) { + if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) { + c.Next() + return + } + role, _ := GetUserRoleFromContext(c) + if role == "admin" { + c.Next() + return + } + response.Forbidden(c, "Backend mode is active. User self-service is disabled.") + c.Abort() + } +} + +// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled. +// Allows: login, login/2fa, logout, refresh (admin needs these). +// Blocks: register, forgot-password, reset-password, OAuth, etc. +func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc { + return func(c *gin.Context) { + if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) { + c.Next() + return + } + path := c.Request.URL.Path + // Allow login, 2FA, logout, refresh, public settings + allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} + for _, suffix := range allowedSuffixes { + if strings.HasSuffix(path, suffix) { + c.Next() + return + } + } + response.Forbidden(c, "Backend mode is active. 
Registration and self-service auth flows are disabled.") + c.Abort() + } +} diff --git a/backend/internal/server/middleware/backend_mode_guard_test.go b/backend/internal/server/middleware/backend_mode_guard_test.go new file mode 100644 index 00000000..8878ebc9 --- /dev/null +++ b/backend/internal/server/middleware/backend_mode_guard_test.go @@ -0,0 +1,239 @@ +//go:build unit + +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type bmSettingRepo struct { + values map[string]string +} + +func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) { + v, ok := r.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return v, nil +} + +func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error { + panic("unexpected Set call") +} + +func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error { + if r.values == nil { + r.values = make(map[string]string, len(settings)) + } + for key, value := range settings { + r.values[key] = value + } + return nil +} + +func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (r *bmSettingRepo) Delete(_ context.Context, _ string) error { + panic("unexpected Delete call") +} + +func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService { + t.Helper() + + repo := &bmSettingRepo{ + values: map[string]string{ + service.SettingKeyBackendModeEnabled: enabled, + }, + } + svc := service.NewSettingService(repo, 
&config.Config{}) + require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{ + BackendModeEnabled: enabled == "true", + })) + + return svc +} + +func stringPtr(v string) *string { + return &v +} + +func TestBackendModeUserGuard(t *testing.T) { + tests := []struct { + name string + nilService bool + enabled string + role *string + wantStatus int + }{ + { + name: "disabled_allows_all", + enabled: "false", + role: stringPtr("user"), + wantStatus: http.StatusOK, + }, + { + name: "nil_service_allows_all", + nilService: true, + role: stringPtr("user"), + wantStatus: http.StatusOK, + }, + { + name: "enabled_admin_allowed", + enabled: "true", + role: stringPtr("admin"), + wantStatus: http.StatusOK, + }, + { + name: "enabled_user_blocked", + enabled: "true", + role: stringPtr("user"), + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_no_role_blocked", + enabled: "true", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_empty_role_blocked", + enabled: "true", + role: stringPtr(""), + wantStatus: http.StatusForbidden, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + if tc.role != nil { + role := *tc.role + r.Use(func(c *gin.Context) { + c.Set(string(ContextKeyUserRole), role) + c.Next() + }) + } + + var svc *service.SettingService + if !tc.nilService { + svc = newBackendModeSettingService(t, tc.enabled) + } + + r.Use(BackendModeUserGuard(svc)) + r.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + r.ServeHTTP(w, req) + + require.Equal(t, tc.wantStatus, w.Code) + }) + } +} + +func TestBackendModeAuthGuard(t *testing.T) { + tests := []struct { + name string + nilService bool + enabled string + path string + wantStatus int + }{ + { + name: "disabled_allows_all", + enabled: "false", + path: "/api/v1/auth/register", + 
wantStatus: http.StatusOK, + }, + { + name: "nil_service_allows_all", + nilService: true, + path: "/api/v1/auth/register", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_login", + enabled: "true", + path: "/api/v1/auth/login", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_login_2fa", + enabled: "true", + path: "/api/v1/auth/login/2fa", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_logout", + enabled: "true", + path: "/api/v1/auth/logout", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_refresh", + enabled: "true", + path: "/api/v1/auth/refresh", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_register", + enabled: "true", + path: "/api/v1/auth/register", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_blocks_forgot_password", + enabled: "true", + path: "/api/v1/auth/forgot-password", + wantStatus: http.StatusForbidden, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + + var svc *service.SettingService + if !tc.nilService { + svc = newBackendModeSettingService(t, tc.enabled) + } + + r.Use(BackendModeAuthGuard(svc)) + r.Any("/*path", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, tc.path, nil) + r.ServeHTTP(w, req) + + require.Equal(t, tc.wantStatus, w.Code) + }) + } +} diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index f8839cfe..ad9c1b5b 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: users} - authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, 
nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) diff --git a/backend/internal/server/middleware/middleware.go b/backend/internal/server/middleware/middleware.go index 26572019..27985cf8 100644 --- a/backend/internal/server/middleware/middleware.go +++ b/backend/internal/server/middleware/middleware.go @@ -2,8 +2,11 @@ package middleware import ( "context" + "net/http" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) @@ -71,3 +74,48 @@ func AbortWithError(c *gin.Context, statusCode int, code, message string) { c.JSON(statusCode, NewErrorResponse(code, message)) c.Abort() } + +// ────────────────────────────────────────────────────────── +// RequireGroupAssignment — 未分组 Key 拦截中间件 +// ────────────────────────────────────────────────────────── + +// GatewayErrorWriter 定义网关错误响应格式(不同协议使用不同格式) +type GatewayErrorWriter func(c *gin.Context, status int, message string) + +// AnthropicErrorWriter 按 Anthropic API 规范输出错误 +func AnthropicErrorWriter(c *gin.Context, status int, message string) { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{"type": "permission_error", "message": message}, + }) +} + +// GoogleErrorWriter 按 Google API 规范输出错误 +func GoogleErrorWriter(c *gin.Context, status int, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "code": status, + "message": message, + "status": googleapi.HTTPStatusToGoogleStatus(status), + }, + }) +} + +// RequireGroupAssignment 检查 API Key 是否已分配到分组, +// 如果未分组且系统设置不允许未分组 Key 调度则返回 403。 +func RequireGroupAssignment(settingService *service.SettingService, writeError GatewayErrorWriter) gin.HandlerFunc { + return func(c *gin.Context) { + apiKey, ok := GetAPIKeyFromContext(c) + if !ok || apiKey.GroupID != nil { + c.Next() + return + } 
+ // 未分组 Key — 检查系统设置 + if settingService.IsUngroupedKeySchedulingAllowed(c.Request.Context()) { + c.Next() + return + } + writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.") + c.Abort() + } +} diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index f061db90..d9ec951e 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -41,7 +41,9 @@ func GetNonceFromContext(c *gin.Context) string { } // SecurityHeaders sets baseline security headers for all responses. -func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { +// getFrameSrcOrigins is an optional function that returns extra origins to inject into frame-src; +// pass nil to disable dynamic frame-src injection. +func SecurityHeaders(cfg config.CSPConfig, getFrameSrcOrigins func() []string) gin.HandlerFunc { policy := strings.TrimSpace(cfg.Policy) if policy == "" { policy = config.DefaultCSPPolicy @@ -51,6 +53,15 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { policy = enhanceCSPPolicy(policy) return func(c *gin.Context) { + finalPolicy := policy + if getFrameSrcOrigins != nil { + for _, origin := range getFrameSrcOrigins() { + if origin != "" { + finalPolicy = addToDirective(finalPolicy, "frame-src", origin) + } + } + } + c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Frame-Options", "DENY") c.Header("Referrer-Policy", "strict-origin-when-cross-origin") @@ -65,12 +76,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { if err != nil { // crypto/rand 失败时降级为无 nonce 的 CSP 策略 log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err) - finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'") - c.Header("Content-Security-Policy", finalPolicy) + c.Header("Content-Security-Policy", 
strings.ReplaceAll(finalPolicy, NonceTemplate, "'unsafe-inline'")) } else { c.Set(CSPNonceKey, nonce) - finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'") - c.Header("Content-Security-Policy", finalPolicy) + c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'nonce-"+nonce+"'")) } } c.Next() diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go index 5a779825..031385d0 100644 --- a/backend/internal/server/middleware/security_headers_test.go +++ b/backend/internal/server/middleware/security_headers_test.go @@ -84,7 +84,7 @@ func TestGetNonceFromContext(t *testing.T) { func TestSecurityHeaders(t *testing.T) { t.Run("sets_basic_security_headers", func(t *testing.T) { cfg := config.CSPConfig{Enabled: false} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -99,7 +99,7 @@ func TestSecurityHeaders(t *testing.T) { t.Run("csp_disabled_no_csp_header", func(t *testing.T) { cfg := config.CSPConfig{Enabled: false} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -115,7 +115,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "default-src 'self'", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -136,7 +136,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -156,7 +156,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := 
SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -180,7 +180,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -199,7 +199,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: " \t\n ", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -217,7 +217,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -235,7 +235,7 @@ func TestSecurityHeaders(t *testing.T) { t.Run("calls_next_handler", func(t *testing.T) { cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) nextCalled := false router := gin.New() @@ -258,7 +258,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) nonces := make(map[string]bool) for i := 0; i < 10; i++ { @@ -376,7 +376,7 @@ func BenchmarkSecurityHeadersMiddleware(b *testing.B) { Enabled: true, Policy: "script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 07b51f23..99701531 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -1,7 +1,10 @@ package server import ( + "context" "log" + "sync/atomic" + "time" "github.com/Wei-Shaw/sub2api/internal/config" 
"github.com/Wei-Shaw/sub2api/internal/handler" @@ -14,6 +17,8 @@ import ( "github.com/redis/go-redis/v9" ) +const frameSrcRefreshTimeout = 5 * time.Second + // SetupRouter 配置路由器中间件和路由 func SetupRouter( r *gin.Engine, @@ -28,11 +33,33 @@ func SetupRouter( cfg *config.Config, redisClient *redis.Client, ) *gin.Engine { + // 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src + var cachedFrameOrigins atomic.Pointer[[]string] + emptyOrigins := []string{} + cachedFrameOrigins.Store(&emptyOrigins) + + refreshFrameOrigins := func() { + ctx, cancel := context.WithTimeout(context.Background(), frameSrcRefreshTimeout) + defer cancel() + origins, err := settingService.GetFrameSrcOrigins(ctx) + if err != nil { + // 获取失败时保留已有缓存,避免 frame-src 被意外清空 + return + } + cachedFrameOrigins.Store(&origins) + } + refreshFrameOrigins() // 启动时初始化 + // 应用中间件 r.Use(middleware2.RequestLogger()) r.Use(middleware2.Logger()) r.Use(middleware2.CORS(cfg.CORS)) - r.Use(middleware2.SecurityHeaders(cfg.Security.CSP)) + r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() []string { + if p := cachedFrameOrigins.Load(); p != nil { + return *p + } + return nil + })) // Serve embedded frontend with settings injection if available if web.HasEmbeddedFrontend() { @@ -40,15 +67,21 @@ func SetupRouter( if err != nil { log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err) r.Use(web.ServeEmbeddedFrontend()) + settingService.SetOnUpdateCallback(refreshFrameOrigins) } else { - // Register cache invalidation callback - settingService.SetOnUpdateCallback(frontendServer.InvalidateCache) + // Register combined callback: invalidate HTML cache + refresh frame origins + settingService.SetOnUpdateCallback(func() { + frontendServer.InvalidateCache() + refreshFrameOrigins() + }) r.Use(frontendServer.Middleware()) } + } else { + settingService.SetOnUpdateCallback(refreshFrameOrigins) } // 注册路由 - registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, 
subscriptionService, opsService, cfg, redisClient) + registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) return r } @@ -63,6 +96,7 @@ func registerRoutes( apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, opsService *service.OpsService, + settingService *service.SettingService, cfg *config.Config, redisClient *redis.Client, ) { @@ -73,9 +107,9 @@ func registerRoutes( v1 := r.Group("/api/v1") // 注册各模块路由 - routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient) - routes.RegisterUserRoutes(v1, h, jwtAuth) - routes.RegisterSoraClientRoutes(v1, h, jwtAuth) + routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService) + routes.RegisterUserRoutes(v1, h, jwtAuth, settingService) + routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService) routes.RegisterAdminRoutes(v1, h, adminAuth) - routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg) + routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index c36c36a0..85bfa6a6 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -58,6 +58,9 @@ func RegisterAdminRoutes( // 数据管理 registerDataManagementRoutes(admin, h) + // 数据库备份恢复 + registerBackupRoutes(admin, h) + // 运维监控(Ops) registerOpsRoutes(admin, h) @@ -78,6 +81,9 @@ func RegisterAdminRoutes( // API Key 管理 registerAdminAPIKeyRoutes(admin, h) + + // 定时测试计划 + registerScheduledTestRoutes(admin, h) } } @@ -168,6 +174,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth) // Dashboard (vNext - raw path for MVP) + ops.GET("/dashboard/snapshot-v2", h.Admin.Ops.GetDashboardSnapshotV2) 
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview) ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend) ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram) @@ -180,6 +187,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { dashboard := admin.Group("/dashboard") { + dashboard.GET("/snapshot-v2", h.Admin.Dashboard.GetSnapshotV2) dashboard.GET("/stats", h.Admin.Dashboard.GetStats) dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics) dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend) @@ -187,6 +195,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats) dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend) dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) + dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking) dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage) dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation) @@ -223,6 +232,9 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { groups.PUT("/:id", h.Admin.Group.Update) groups.DELETE("/:id", h.Admin.Group.Delete) groups.GET("/:id/stats", h.Admin.Group.GetStats) + groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers) + groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers) + groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers) groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys) } } @@ -239,6 +251,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.PUT("/:id", h.Admin.Account.Update) accounts.DELETE("/:id", h.Admin.Account.Delete) 
accounts.POST("/:id/test", h.Admin.Account.Test) + accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState) accounts.POST("/:id/refresh", h.Admin.Account.Refresh) accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier) accounts.GET("/:id/stats", h.Admin.Account.GetStats) @@ -247,6 +260,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats) accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) + accounts.POST("/:id/reset-quota", h.Admin.Account.ResetQuota) accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable) accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable) accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) @@ -257,6 +271,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials) accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier) accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate) + accounts.POST("/batch-clear-error", h.Admin.Account.BatchClearError) + accounts.POST("/batch-refresh", h.Admin.Account.BatchRefresh) // Antigravity 默认模型映射 accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping) @@ -386,6 +402,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // 流超时处理配置 adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) + // 请求整流器配置 + adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings) + adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings) + // Beta 策略配置 + adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings) + adminSettings.PUT("/beta-policy", 
h.Admin.Setting.UpdateBetaPolicySettings) // Sora S3 存储配置 adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings) adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings) @@ -421,6 +443,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } +func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + backup := admin.Group("/backups") + { + // S3 存储配置 + backup.GET("/s3-config", h.Admin.Backup.GetS3Config) + backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config) + backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection) + + // 定时备份配置 + backup.GET("/schedule", h.Admin.Backup.GetSchedule) + backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule) + + // 备份操作 + backup.POST("", h.Admin.Backup.CreateBackup) + backup.GET("", h.Admin.Backup.ListBackups) + backup.GET("/:id", h.Admin.Backup.GetBackup) + backup.DELETE("/:id", h.Admin.Backup.DeleteBackup) + backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL) + + // 恢复操作 + backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup) + } +} + func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) { system := admin.Group("/system") { @@ -441,6 +487,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) { subscriptions.POST("/assign", h.Admin.Subscription.Assign) subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign) subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend) + subscriptions.POST("/:id/reset-quota", h.Admin.Subscription.ResetQuota) subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke) } @@ -476,6 +523,18 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } +func registerScheduledTestRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + plans := admin.Group("/scheduled-test-plans") + { + plans.POST("", h.Admin.ScheduledTest.Create) + plans.PUT("/:id", h.Admin.ScheduledTest.Update) + plans.DELETE("/:id", h.Admin.ScheduledTest.Delete) + 
plans.GET("/:id/results", h.Admin.ScheduledTest.ListResults) + } + // Nested under accounts + admin.GET("/accounts/:id/scheduled-test-plans", h.Admin.ScheduledTest.ListByAccount) +} + func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) { rules := admin.Group("/error-passthrough-rules") { diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index c168820c..a6c0ecf5 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -6,6 +6,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/middleware" servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" @@ -17,12 +18,14 @@ func RegisterAuthRoutes( h *handler.Handlers, jwtAuth servermiddleware.JWTAuthMiddleware, redisClient *redis.Client, + settingService *service.SettingService, ) { // 创建速率限制器 rateLimiter := middleware.NewRateLimiter(redisClient) // 公开接口 auth := v1.Group("/auth") + auth.Use(servermiddleware.BackendModeAuthGuard(settingService)) { // 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close) auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{ @@ -61,6 +64,12 @@ func RegisterAuthRoutes( }), h.Auth.ResetPassword) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) + auth.POST("/oauth/linuxdo/complete-registration", + rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteLinuxDoOAuthRegistration, + ) } // 公开设置(无需认证) @@ -72,6 +81,7 @@ func RegisterAuthRoutes( // 需要认证的当前用户信息 authenticated := v1.Group("") authenticated.Use(gin.HandlerFunc(jwtAuth)) + 
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService)) { authenticated.GET("/auth/me", h.Auth.GetCurrentUser) // 撤销所有会话(需要认证) diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go index 5ce8497c..4f411cec 100644 --- a/backend/internal/server/routes/auth_rate_limit_test.go +++ b/backend/internal/server/routes/auth_rate_limit_test.go @@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine { c.Next() }), redisClient, + nil, ) return router diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 6bd91b85..fe820830 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -19,6 +19,7 @@ func RegisterGatewayRoutes( apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, opsService *service.OpsService, + settingService *service.SettingService, cfg *config.Config, ) { bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) @@ -29,30 +30,51 @@ func RegisterGatewayRoutes( soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize) clientRequestID := middleware.ClientRequestID() opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) + endpointNorm := handler.InboundEndpointMiddleware() + + // 未分组 Key 拦截中间件(按协议格式区分错误响应) + requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter) + requireGroupGoogle := middleware.RequireGroupAssignment(settingService, middleware.GoogleErrorWriter) // API网关(Claude API兼容) gateway := r.Group("/v1") gateway.Use(bodyLimit) gateway.Use(clientRequestID) gateway.Use(opsErrorLogger) + gateway.Use(endpointNorm) gateway.Use(gin.HandlerFunc(apiKeyAuth)) + gateway.Use(requireGroupAnthropic) { - gateway.POST("/messages", h.Gateway.Messages) - gateway.POST("/messages/count_tokens", h.Gateway.CountTokens) + // /v1/messages: auto-route based on 
group platform + gateway.POST("/messages", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.Messages(c) + return + } + h.Gateway.Messages(c) + }) + // /v1/messages/count_tokens: OpenAI groups get 404 + gateway.POST("/messages/count_tokens", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + c.JSON(http.StatusNotFound, gin.H{ + "type": "error", + "error": gin.H{ + "type": "not_found_error", + "message": "Token counting is not supported for this platform", + }, + }) + return + } + h.Gateway.CountTokens(c) + }) gateway.GET("/models", h.Gateway.Models) gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API gateway.POST("/responses", h.OpenAIGateway.Responses) + gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses) gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) - // 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。 - gateway.POST("/chat/completions", func(c *gin.Context) { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "type": "invalid_request_error", - "message": "Unsupported legacy protocol: /v1/chat/completions is not supported. 
Please use /v1/responses.", - }, - }) - }) + // OpenAI Chat Completions API + gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -60,7 +82,9 @@ func RegisterGatewayRoutes( gemini.Use(bodyLimit) gemini.Use(clientRequestID) gemini.Use(opsErrorLogger) + gemini.Use(endpointNorm) gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) + gemini.Use(requireGroupGoogle) { gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) @@ -69,19 +93,24 @@ func RegisterGatewayRoutes( } // OpenAI Responses API(不带v1前缀的别名) - r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) - r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket) + r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) + r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) + r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) + // OpenAI Chat Completions API(不带v1前缀的别名) + r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions) // Antigravity 模型列表 - r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels) + r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) // Antigravity 专用路由(仅使用 antigravity 账户,不混合调度) antigravityV1 := r.Group("/antigravity/v1") antigravityV1.Use(bodyLimit) 
antigravityV1.Use(clientRequestID) antigravityV1.Use(opsErrorLogger) + antigravityV1.Use(endpointNorm) antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1.Use(gin.HandlerFunc(apiKeyAuth)) + antigravityV1.Use(requireGroupAnthropic) { antigravityV1.POST("/messages", h.Gateway.Messages) antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens) @@ -93,8 +122,10 @@ func RegisterGatewayRoutes( antigravityV1Beta.Use(bodyLimit) antigravityV1Beta.Use(clientRequestID) antigravityV1Beta.Use(opsErrorLogger) + antigravityV1Beta.Use(endpointNorm) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) + antigravityV1Beta.Use(requireGroupGoogle) { antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels) antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) @@ -106,8 +137,10 @@ func RegisterGatewayRoutes( soraV1.Use(soraBodyLimit) soraV1.Use(clientRequestID) soraV1.Use(opsErrorLogger) + soraV1.Use(endpointNorm) soraV1.Use(middleware.ForcePlatform(service.PlatformSora)) soraV1.Use(gin.HandlerFunc(apiKeyAuth)) + soraV1.Use(requireGroupAnthropic) { soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions) soraV1.GET("/models", h.Gateway.Models) @@ -122,3 +155,12 @@ func RegisterGatewayRoutes( // Sora 媒体代理(签名 URL,无需 API Key) r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned) } + +// getGroupPlatform extracts the group platform from the API Key stored in context. 
+func getGroupPlatform(c *gin.Context) string { + apiKey, ok := middleware.GetAPIKeyFromContext(c) + if !ok || apiKey.Group == nil { + return "" + } + return apiKey.Group.Platform +} diff --git a/backend/internal/server/routes/gateway_test.go b/backend/internal/server/routes/gateway_test.go new file mode 100644 index 00000000..00edd31b --- /dev/null +++ b/backend/internal/server/routes/gateway_test.go @@ -0,0 +1,51 @@ +package routes + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/handler" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func newGatewayRoutesTestRouter() *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + + RegisterGatewayRoutes( + router, + &handler.Handlers{ + Gateway: &handler.GatewayHandler{}, + OpenAIGateway: &handler.OpenAIGatewayHandler{}, + SoraGateway: &handler.SoraGatewayHandler{}, + }, + servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) { + c.Next() + }), + nil, + nil, + nil, + nil, + &config.Config{}, + ) + + return router +} + +func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) { + router := newGatewayRoutesTestRouter() + + for _, path := range []string{"/v1/responses/compact", "/responses/compact"} { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-5"}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path) + } +} diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go index 40ae0436..13fceb81 100644 --- a/backend/internal/server/routes/sora_client.go +++ b/backend/internal/server/routes/sora_client.go @@ -3,6 +3,7 @@ package routes 
import ( "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) @@ -12,6 +13,7 @@ func RegisterSoraClientRoutes( v1 *gin.RouterGroup, h *handler.Handlers, jwtAuth middleware.JWTAuthMiddleware, + settingService *service.SettingService, ) { if h.SoraClient == nil { return @@ -19,6 +21,7 @@ func RegisterSoraClientRoutes( authenticated := v1.Group("/sora") authenticated.Use(gin.HandlerFunc(jwtAuth)) + authenticated.Use(middleware.BackendModeUserGuard(settingService)) { authenticated.POST("/generate", h.SoraClient.Generate) authenticated.GET("/generations", h.SoraClient.ListGenerations) diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index d0ed2489..c3b82742 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -3,6 +3,7 @@ package routes import ( "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) @@ -12,9 +13,11 @@ func RegisterUserRoutes( v1 *gin.RouterGroup, h *handler.Handlers, jwtAuth middleware.JWTAuthMiddleware, + settingService *service.SettingService, ) { authenticated := v1.Group("") authenticated.Use(gin.HandlerFunc(jwtAuth)) + authenticated.Use(middleware.BackendModeUserGuard(settingService)) { // 用户接口 user := authenticated.Group("/user") diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index c76c817e..b6408f5f 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -3,6 +3,7 @@ package service import ( "encoding/json" + "errors" "hash/fnv" "reflect" "sort" @@ -10,6 +11,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/domain" ) @@ -27,6 +29,7 @@ type Account struct { // 
RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。 // 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。 RateMultiplier *float64 + LoadFactor *int // 调度负载因子;nil 表示使用 Concurrency Status string ErrorMessage string LastUsedAt *time.Time @@ -87,6 +90,19 @@ func (a *Account) BillingRateMultiplier() float64 { return *a.RateMultiplier } +func (a *Account) EffectiveLoadFactor() int { + if a == nil { + return 1 + } + if a.LoadFactor != nil && *a.LoadFactor > 0 { + return *a.LoadFactor + } + if a.Concurrency > 0 { + return a.Concurrency + } + return 1 +} + func (a *Account) IsSchedulable() bool { if !a.IsActive() || !a.Schedulable { return false @@ -397,6 +413,7 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri if a.Platform == domain.PlatformAntigravity { return domain.DefaultAntigravityModelMapping } + // Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整) return nil } if len(rawMapping) == 0 { @@ -506,16 +523,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool { // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) // 如果未配置 mapping,返回原始模型名 func (a *Account) GetMappedModel(requestedModel string) string { + mappedModel, _ := a.ResolveMappedModel(requestedModel) + return mappedModel +} + +// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。 +// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。 +func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) { mapping := a.GetModelMapping() if len(mapping) == 0 { - return requestedModel + return requestedModel, false } // 精确匹配优先 if mappedModel, exists := mapping[requestedModel]; exists { - return mappedModel + return mappedModel, true } // 通配符匹配(最长优先) - return matchWildcardMapping(mapping, requestedModel) + return matchWildcardMappingResult(mapping, requestedModel) } func (a *Account) GetBaseURL() string { @@ -589,9 +613,7 @@ func matchWildcard(pattern, str string) bool { return matchAntigravityWildcard(pattern, str) } -// matchWildcardMapping 通配符映射匹配(最长优先) -// 如果没有匹配,返回原始字符串 
-func matchWildcardMapping(mapping map[string]string, requestedModel string) string { +func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) { // 收集所有匹配的 pattern,按长度降序排序(最长优先) type patternMatch struct { pattern string @@ -606,7 +628,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri } if len(matches) == 0 { - return requestedModel // 无匹配,返回原始模型名 + return requestedModel, false // 无匹配,返回原始模型名 } // 按 pattern 长度降序排序 @@ -617,7 +639,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri return matches[i].pattern < matches[j].pattern }) - return matches[0].target + return matches[0].target, true } func (a *Account) IsCustomErrorCodesEnabled() bool { @@ -632,6 +654,75 @@ func (a *Account) IsCustomErrorCodesEnabled() bool { return false } +// IsPoolMode 检查 API Key 账号是否启用池模式。 +// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。 +func (a *Account) IsPoolMode() bool { + if !a.IsAPIKeyOrBedrock() || a.Credentials == nil { + return false + } + if v, ok := a.Credentials["pool_mode"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + +const ( + defaultPoolModeRetryCount = 3 + maxPoolModeRetryCount = 10 +) + +// GetPoolModeRetryCount 返回池模式同账号重试次数。 +// 未配置或配置非法时回退为默认值 3;小于 0 按 0 处理;过大则截断到 10。 +func (a *Account) GetPoolModeRetryCount() int { + if a == nil || !a.IsPoolMode() || a.Credentials == nil { + return defaultPoolModeRetryCount + } + raw, ok := a.Credentials["pool_mode_retry_count"] + if !ok || raw == nil { + return defaultPoolModeRetryCount + } + count := parsePoolModeRetryCount(raw) + if count < 0 { + return 0 + } + if count > maxPoolModeRetryCount { + return maxPoolModeRetryCount + } + return count +} + +func parsePoolModeRetryCount(value any) int { + switch v := value.(type) { + case int: + return v + case int64: + return int(v) + case float64: + return int(v) + case json.Number: + if i, err := v.Int64(); err == nil { + return int(i) + } + case 
string: + if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil { + return i + } + } + return defaultPoolModeRetryCount +} + +// isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码 +func isPoolModeRetryableStatus(statusCode int) bool { + switch statusCode { + case 401, 403, 429: + return true + default: + return false + } +} + func (a *Account) GetCustomErrorCodes() []int { if a.Credentials == nil { return nil @@ -680,6 +771,19 @@ func (a *Account) IsInterceptWarmupEnabled() bool { return false } +func (a *Account) IsBedrock() bool { + return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock +} + +func (a *Account) IsBedrockAPIKey() bool { + return a.IsBedrock() && a.GetCredential("auth_mode") == "apikey" +} + +// IsAPIKeyOrBedrock 返回账号类型是否支持配额和池模式等特性 +func (a *Account) IsAPIKeyOrBedrock() bool { + return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock +} + func (a *Account) IsOpenAI() bool { return a.Platform == PlatformOpenAI } @@ -797,6 +901,22 @@ func (a *Account) IsMixedSchedulingEnabled() bool { return false } +// IsOveragesEnabled 检查 Antigravity 账号是否启用 AI Credits 超量请求。 +func (a *Account) IsOveragesEnabled() bool { + if a.Platform != PlatformAntigravity { + return false + } + if a.Extra == nil { + return false + } + if v, ok := a.Extra["allow_overages"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + // IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。 // // 新字段:accounts.extra.openai_passthrough。 @@ -852,15 +972,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool { } const ( - OpenAIWSIngressModeOff = "off" - OpenAIWSIngressModeShared = "shared" - OpenAIWSIngressModeDedicated = "dedicated" + OpenAIWSIngressModeOff = "off" + OpenAIWSIngressModeShared = "shared" + OpenAIWSIngressModeDedicated = "dedicated" + OpenAIWSIngressModeCtxPool = "ctx_pool" + OpenAIWSIngressModePassthrough = "passthrough" ) func normalizeOpenAIWSIngressMode(mode string) string { switch 
strings.ToLower(strings.TrimSpace(mode)) { case OpenAIWSIngressModeOff: return OpenAIWSIngressModeOff + case OpenAIWSIngressModeCtxPool: + return OpenAIWSIngressModeCtxPool + case OpenAIWSIngressModePassthrough: + return OpenAIWSIngressModePassthrough case OpenAIWSIngressModeShared: return OpenAIWSIngressModeShared case OpenAIWSIngressModeDedicated: @@ -872,18 +998,21 @@ func normalizeOpenAIWSIngressMode(mode string) string { func normalizeOpenAIWSIngressDefaultMode(mode string) string { if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" { + if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated { + return OpenAIWSIngressModeCtxPool + } return normalized } - return OpenAIWSIngressModeShared + return OpenAIWSIngressModeCtxPool } -// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。 +// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。 // // 优先级: // 1. 分类型 mode 新字段(string) // 2. 分类型 enabled 旧字段(bool) // 3. 兼容 enabled 旧字段(bool) -// 4. defaultMode(非法时回退 shared) +// 4. 
defaultMode(非法时回退 ctx_pool) func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string { resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode) if a == nil || !a.IsOpenAI() { @@ -918,7 +1047,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri return "", false } if enabled { - return OpenAIWSIngressModeShared, true + return OpenAIWSIngressModeCtxPool, true } return OpenAIWSIngressModeOff, true } @@ -945,6 +1074,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri if mode, ok := resolveBoolMode("openai_ws_enabled"); ok { return mode } + // 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。 + if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated { + return OpenAIWSIngressModeCtxPool + } return resolvedDefault } @@ -1032,6 +1165,26 @@ func (a *Account) IsTLSFingerprintEnabled() bool { return false } +// GetUserMsgQueueMode 获取用户消息队列模式 +// "serialize" = 串行队列, "throttle" = 软性限速, "" = 未设置(使用全局配置) +func (a *Account) GetUserMsgQueueMode() string { + if a.Extra == nil { + return "" + } + // 优先读取新字段 user_msg_queue_mode(白名单校验,非法值视为未设置) + if mode, ok := a.Extra["user_msg_queue_mode"].(string); ok && mode != "" { + if mode == config.UMQModeSerialize || mode == config.UMQModeThrottle { + return mode + } + return "" // 非法值 fallback 到全局配置 + } + // 向后兼容: user_msg_queue_enabled: true → "serialize" + if enabled, ok := a.Extra["user_msg_queue_enabled"].(bool); ok && enabled { + return config.UMQModeSerialize + } + return "" +} + // IsSessionIDMaskingEnabled 检查是否启用会话ID伪装 // 仅适用于 Anthropic OAuth/SetupToken 类型账号 // 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID, @@ -1083,6 +1236,348 @@ func (a *Account) GetCacheTTLOverrideTarget() string { return "5m" } +// GetQuotaLimit 获取 API Key 账号的配额限制(美元) +// 返回 0 表示未启用 +func (a *Account) GetQuotaLimit() float64 { + return a.getExtraFloat64("quota_limit") +} + +// GetQuotaUsed 获取 API Key 账号的已用配额(美元) +func (a 
*Account) GetQuotaUsed() float64 { + return a.getExtraFloat64("quota_used") +} + +// GetQuotaDailyLimit 获取日额度限制(美元),0 表示未启用 +func (a *Account) GetQuotaDailyLimit() float64 { + return a.getExtraFloat64("quota_daily_limit") +} + +// GetQuotaDailyUsed 获取当日已用额度(美元) +func (a *Account) GetQuotaDailyUsed() float64 { + return a.getExtraFloat64("quota_daily_used") +} + +// GetQuotaWeeklyLimit 获取周额度限制(美元),0 表示未启用 +func (a *Account) GetQuotaWeeklyLimit() float64 { + return a.getExtraFloat64("quota_weekly_limit") +} + +// GetQuotaWeeklyUsed 获取本周已用额度(美元) +func (a *Account) GetQuotaWeeklyUsed() float64 { + return a.getExtraFloat64("quota_weekly_used") +} + +// getExtraFloat64 从 Extra 中读取指定 key 的 float64 值 +func (a *Account) getExtraFloat64(key string) float64 { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra[key]; ok { + return parseExtraFloat64(v) + } + return 0 +} + +// getExtraTime 从 Extra 中读取 RFC3339 时间戳 +func (a *Account) getExtraTime(key string) time.Time { + if a.Extra == nil { + return time.Time{} + } + if v, ok := a.Extra[key]; ok { + if s, ok := v.(string); ok { + if t, err := time.Parse(time.RFC3339Nano, s); err == nil { + return t + } + if t, err := time.Parse(time.RFC3339, s); err == nil { + return t + } + } + } + return time.Time{} +} + +// getExtraString 从 Extra 中读取指定 key 的字符串值 +func (a *Account) getExtraString(key string) string { + if a.Extra == nil { + return "" + } + if v, ok := a.Extra[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +// getExtraInt 从 Extra 中读取指定 key 的 int 值 +func (a *Account) getExtraInt(key string) int { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra[key]; ok { + return int(parseExtraFloat64(v)) + } + return 0 +} + +// GetQuotaDailyResetMode 获取日额度重置模式:"rolling"(默认)或 "fixed" +func (a *Account) GetQuotaDailyResetMode() string { + if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" { + return "fixed" + } + return "rolling" +} + +// GetQuotaDailyResetHour 获取固定重置的小时(0-23),默认 0 
+func (a *Account) GetQuotaDailyResetHour() int { + return a.getExtraInt("quota_daily_reset_hour") +} + +// GetQuotaWeeklyResetMode 获取周额度重置模式:"rolling"(默认)或 "fixed" +func (a *Account) GetQuotaWeeklyResetMode() string { + if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" { + return "fixed" + } + return "rolling" +} + +// GetQuotaWeeklyResetDay 获取固定重置的星期几(0=周日, 1=周一, ..., 6=周六),默认 1(周一) +func (a *Account) GetQuotaWeeklyResetDay() int { + if a.Extra == nil { + return 1 + } + if _, ok := a.Extra["quota_weekly_reset_day"]; !ok { + return 1 + } + return a.getExtraInt("quota_weekly_reset_day") +} + +// GetQuotaWeeklyResetHour 获取周配额固定重置的小时(0-23),默认 0 +func (a *Account) GetQuotaWeeklyResetHour() int { + return a.getExtraInt("quota_weekly_reset_hour") +} + +// GetQuotaResetTimezone 获取固定重置的时区名(IANA),默认 "UTC" +func (a *Account) GetQuotaResetTimezone() string { + if tz := a.getExtraString("quota_reset_timezone"); tz != "" { + return tz + } + return "UTC" +} + +// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点 +func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time { + t := after.In(tz) + today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + if !after.Before(today) { + return today.AddDate(0, 0, 1) + } + return today +} + +// lastFixedDailyReset 计算 now 之前最近一次的每日固定重置时间点 +func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time { + t := now.In(tz) + today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + if now.Before(today) { + return today.AddDate(0, 0, -1) + } + return today +} + +// nextFixedWeeklyReset 计算在 after 之后的下一个每周固定重置时间点 +// day: 0=Sunday, 1=Monday, ..., 6=Saturday +func nextFixedWeeklyReset(day, hour int, tz *time.Location, after time.Time) time.Time { + t := after.In(tz) + todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + currentDay := int(todayReset.Weekday()) + + daysForward := (day - currentDay + 7) % 7 + if daysForward == 0 && 
!after.Before(todayReset) { + daysForward = 7 + } + return todayReset.AddDate(0, 0, daysForward) +} + +// lastFixedWeeklyReset 计算 now 之前最近一次的每周固定重置时间点 +func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time { + t := now.In(tz) + todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + currentDay := int(todayReset.Weekday()) + + daysBack := (currentDay - day + 7) % 7 + if daysBack == 0 && now.Before(todayReset) { + daysBack = 7 + } + return todayReset.AddDate(0, 0, -daysBack) +} + +// isFixedDailyPeriodExpired 检查日配额是否在固定时间模式下已过期 +func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool { + if periodStart.IsZero() { + return true + } + tz, err := time.LoadLocation(a.GetQuotaResetTimezone()) + if err != nil { + tz = time.UTC + } + lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now()) + return periodStart.Before(lastReset) +} + +// isFixedWeeklyPeriodExpired 检查周配额是否在固定时间模式下已过期 +func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool { + if periodStart.IsZero() { + return true + } + tz, err := time.LoadLocation(a.GetQuotaResetTimezone()) + if err != nil { + tz = time.UTC + } + lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now()) + return periodStart.Before(lastReset) +} + +// ComputeQuotaResetAt 根据当前配置计算并填充 extra 中的 quota_daily_reset_at / quota_weekly_reset_at +// 在保存账号配置时调用 +func ComputeQuotaResetAt(extra map[string]any) { + now := time.Now() + tzName, _ := extra["quota_reset_timezone"].(string) + if tzName == "" { + tzName = "UTC" + } + tz, err := time.LoadLocation(tzName) + if err != nil { + tz = time.UTC + } + + // 日配额固定重置时间 + if mode, _ := extra["quota_daily_reset_mode"].(string); mode == "fixed" { + hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"])) + if hour < 0 || hour > 23 { + hour = 0 + } + resetAt := nextFixedDailyReset(hour, tz, now) + extra["quota_daily_reset_at"] = 
resetAt.UTC().Format(time.RFC3339) + } else { + delete(extra, "quota_daily_reset_at") + } + + // 周配额固定重置时间 + if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" { + day := 1 // 默认周一 + if d, ok := extra["quota_weekly_reset_day"]; ok { + day = int(parseExtraFloat64(d)) + } + if day < 0 || day > 6 { + day = 1 + } + hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"])) + if hour < 0 || hour > 23 { + hour = 0 + } + resetAt := nextFixedWeeklyReset(day, hour, tz, now) + extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339) + } else { + delete(extra, "quota_weekly_reset_at") + } +} + +// ValidateQuotaResetConfig 校验配额固定重置时间配置的合法性 +func ValidateQuotaResetConfig(extra map[string]any) error { + if extra == nil { + return nil + } + // 校验时区 + if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" { + if _, err := time.LoadLocation(tz); err != nil { + return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name") + } + } + // 日配额重置模式 + if mode, ok := extra["quota_daily_reset_mode"].(string); ok { + if mode != "rolling" && mode != "fixed" { + return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'") + } + } + // 日配额重置小时 + if v, ok := extra["quota_daily_reset_hour"]; ok { + hour := int(parseExtraFloat64(v)) + if hour < 0 || hour > 23 { + return errors.New("quota_daily_reset_hour must be between 0 and 23") + } + } + // 周配额重置模式 + if mode, ok := extra["quota_weekly_reset_mode"].(string); ok { + if mode != "rolling" && mode != "fixed" { + return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'") + } + } + // 周配额重置星期几 + if v, ok := extra["quota_weekly_reset_day"]; ok { + day := int(parseExtraFloat64(v)) + if day < 0 || day > 6 { + return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)") + } + } + // 周配额重置小时 + if v, ok := extra["quota_weekly_reset_hour"]; ok { + hour := int(parseExtraFloat64(v)) + if hour < 0 || hour > 23 { + return 
errors.New("quota_weekly_reset_hour must be between 0 and 23") + } + } + return nil +} + +// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制 +func (a *Account) HasAnyQuotaLimit() bool { + return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0 +} + +// isPeriodExpired 检查指定周期(自 periodStart 起经过 dur)是否已过期 +func isPeriodExpired(periodStart time.Time, dur time.Duration) bool { + if periodStart.IsZero() { + return true // 从未使用过,视为过期(下次 increment 会初始化) + } + return time.Since(periodStart) >= dur +} + +// IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true) +func (a *Account) IsQuotaExceeded() bool { + // 总额度 + if limit := a.GetQuotaLimit(); limit > 0 && a.GetQuotaUsed() >= limit { + return true + } + // 日额度(周期过期视为未超限,下次 increment 会重置) + if limit := a.GetQuotaDailyLimit(); limit > 0 { + start := a.getExtraTime("quota_daily_start") + var expired bool + if a.GetQuotaDailyResetMode() == "fixed" { + expired = a.isFixedDailyPeriodExpired(start) + } else { + expired = isPeriodExpired(start, 24*time.Hour) + } + if !expired && a.GetQuotaDailyUsed() >= limit { + return true + } + } + // 周额度 + if limit := a.GetQuotaWeeklyLimit(); limit > 0 { + start := a.getExtraTime("quota_weekly_start") + var expired bool + if a.GetQuotaWeeklyResetMode() == "fixed" { + expired = a.isFixedWeeklyPeriodExpired(start) + } else { + expired = isPeriodExpired(start, 7*24*time.Hour) + } + if !expired && a.GetQuotaWeeklyUsed() >= limit { + return true + } + } + return false +} + // GetWindowCostLimit 获取 5h 窗口费用阈值(美元) // 返回 0 表示未启用 func (a *Account) GetWindowCostLimit() float64 { diff --git a/backend/internal/service/account_load_factor_test.go b/backend/internal/service/account_load_factor_test.go new file mode 100644 index 00000000..a4d78a4b --- /dev/null +++ b/backend/internal/service/account_load_factor_test.go @@ -0,0 +1,46 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func intPtrHelper(v int) *int { return &v } + 
+func TestEffectiveLoadFactor_NilAccount(t *testing.T) { + var a *Account + require.Equal(t, 1, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) { + a := &Account{Concurrency: 5} + require.Equal(t, 5, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) { + a := &Account{Concurrency: 0} + require.Equal(t, 1, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) { + a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(20)} + require.Equal(t, 20, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) { + a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(0)} + require.Equal(t, 5, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) { + a := &Account{Concurrency: 3, LoadFactor: intPtrHelper(-1)} + require.Equal(t, 3, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) { + a := &Account{Concurrency: 0, LoadFactor: intPtrHelper(0)} + require.Equal(t, 1, a.EffectiveLoadFactor()) +} diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go index a85c68ec..50c2b7cb 100644 --- a/backend/internal/service/account_openai_passthrough_test.go +++ b/backend/internal/service/account_openai_passthrough_test.go @@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) { } func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { - t.Run("default fallback to shared", func(t *testing.T) { + t.Run("default fallback to ctx_pool", func(t *testing.T) { account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}, } - require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("")) - 
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("")) + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) }) t.Run("oauth mode field has highest priority", func(t *testing.T) { @@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{ - "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, "openai_oauth_responses_websockets_v2_enabled": false, "responses_websockets_v2_enabled": false, }, } - require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool)) }) - t.Run("legacy enabled maps to shared", func(t *testing.T) { + t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) { account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeAPIKey, @@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { "responses_websockets_v2_enabled": true, }, } - require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) { + shared := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared, + }, + } + dedicated := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: 
map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared)) + require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated)) }) t.Run("legacy disabled maps to off", func(t *testing.T) { @@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { "responses_websockets_v2_enabled": true, }, } - require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool)) }) t.Run("non openai always off", func(t *testing.T) { diff --git a/backend/internal/service/account_pool_mode_test.go b/backend/internal/service/account_pool_mode_test.go new file mode 100644 index 00000000..98429bb1 --- /dev/null +++ b/backend/internal/service/account_pool_mode_test.go @@ -0,0 +1,117 @@ +//go:build unit + +package service + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetPoolModeRetryCount(t *testing.T) { + tests := []struct { + name string + account *Account + expected int + }{ + { + name: "default_when_not_pool_mode", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{}, + }, + expected: defaultPoolModeRetryCount, + }, + { + name: "default_when_missing_retry_count", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + }, + }, + expected: defaultPoolModeRetryCount, + }, + { + name: 
"supports_float64_from_json_credentials", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": float64(5), + }, + }, + expected: 5, + }, + { + name: "supports_json_number", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": json.Number("4"), + }, + }, + expected: 4, + }, + { + name: "supports_string_value", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": "2", + }, + }, + expected: 2, + }, + { + name: "negative_value_is_clamped_to_zero", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": -1, + }, + }, + expected: 0, + }, + { + name: "oversized_value_is_clamped_to_max", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": 99, + }, + }, + expected: maxPoolModeRetryCount, + }, + { + name: "invalid_value_falls_back_to_default", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": "oops", + }, + }, + expected: defaultPoolModeRetryCount, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount()) + }) + } +} diff --git a/backend/internal/service/account_quota_reset_test.go b/backend/internal/service/account_quota_reset_test.go new file mode 100644 index 00000000..45a4bad6 --- /dev/null +++ b/backend/internal/service/account_quota_reset_test.go @@ -0,0 +1,516 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// nextFixedDailyReset +// --------------------------------------------------------------------------- + +func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) { + tz := time.UTC + // 2026-03-14 06:00 UTC, reset hour = 9 + after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz) + got := nextFixedDailyReset(9, tz, after) + want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedDailyReset_AtResetHour(t *testing.T) { + tz := time.UTC + // Exactly at reset hour → should return tomorrow + after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz) + got := nextFixedDailyReset(9, tz, after) + want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedDailyReset_AfterResetHour(t *testing.T) { + tz := time.UTC + // After reset hour → should return tomorrow + after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz) + got := nextFixedDailyReset(9, tz, after) + want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedDailyReset_MidnightReset(t *testing.T) { + tz := time.UTC + // Reset at hour 0 (midnight), currently 23:59 + after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz) + got := nextFixedDailyReset(0, tz, after) + want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) { + tz, err := time.LoadLocation("Asia/Shanghai") + require.NoError(t, err) + + // 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST) + after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC) + got := nextFixedDailyReset(9, tz, after) + // Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC + want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +// --------------------------------------------------------------------------- +// 
lastFixedDailyReset +// --------------------------------------------------------------------------- + +func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) { + tz := time.UTC + now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz) + got := lastFixedDailyReset(9, tz, now) + // Before today's 9:00 → yesterday 9:00 + want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestLastFixedDailyReset_AtResetHour(t *testing.T) { + tz := time.UTC + now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz) + got := lastFixedDailyReset(9, tz, now) + // At exactly 9:00 → today 9:00 + want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestLastFixedDailyReset_AfterResetHour(t *testing.T) { + tz := time.UTC + now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz) + got := lastFixedDailyReset(9, tz, now) + // After 9:00 → today 9:00 + want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +// --------------------------------------------------------------------------- +// nextFixedWeeklyReset +// --------------------------------------------------------------------------- + +func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) { + tz := time.UTC + // 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9 + after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz) + got := nextFixedWeeklyReset(1, 9, tz, after) + // Next Monday = 2026-03-16 + want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) { + tz := time.UTC + // 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00 + after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz) + got := nextFixedWeeklyReset(1, 9, tz, after) + // Today at 9:00 + want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) { + tz := time.UTC + // 2026-03-16 is Monday, target = Monday, 
hour = 9, exactly at 9:00 + after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz) + got := nextFixedWeeklyReset(1, 9, tz, after) + // Next Monday at 9:00 + want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) { + tz := time.UTC + // 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00 + after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz) + got := nextFixedWeeklyReset(1, 9, tz, after) + // Next Monday at 9:00 + want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) { + tz := time.UTC + // 2026-03-18 is Wednesday (day=3), target = Monday (day=1) + after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz) + got := nextFixedWeeklyReset(1, 9, tz, after) + // Next Monday = 2026-03-23 + want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedWeeklyReset_Sunday(t *testing.T) { + tz := time.UTC + // 2026-03-14 is Saturday (day=6), target = Sunday (day=0) + after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz) + got := nextFixedWeeklyReset(0, 0, tz, after) + // Next Sunday = 2026-03-15 + want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +// --------------------------------------------------------------------------- +// lastFixedWeeklyReset +// --------------------------------------------------------------------------- + +func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) { + tz := time.UTC + // 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00 + now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz) + got := lastFixedWeeklyReset(1, 9, tz, now) + // Today at 9:00 + want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) { + tz := time.UTC + // 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00 + now := time.Date(2026, 
3, 16, 6, 0, 0, 0, tz) + got := lastFixedWeeklyReset(1, 9, tz, now) + // Last Monday at 9:00 = 2026-03-09 + want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) { + tz := time.UTC + // 2026-03-18 is Wednesday (day=3), target = Monday (day=1) + now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz) + got := lastFixedWeeklyReset(1, 9, tz, now) + // Last Monday = 2026-03-16 + want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +// --------------------------------------------------------------------------- +// isFixedDailyPeriodExpired +// --------------------------------------------------------------------------- + +func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + assert.True(t, a.isFixedDailyPeriodExpired(time.Time{})) +} + +func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + // Period started after the most recent reset → not expired + // (This test uses a time very close to "now", which is after the last reset) + periodStart := time.Now().Add(-1 * time.Minute) + assert.False(t, a.isFixedDailyPeriodExpired(periodStart)) +} + +func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + // Period started 3 days ago → definitely expired + periodStart := time.Now().Add(-72 * time.Hour) + assert.True(t, a.isFixedDailyPeriodExpired(periodStart)) +} + +func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_daily_reset_mode": 
"fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "Invalid/Timezone", + }} + // Invalid timezone falls back to UTC + periodStart := time.Now().Add(-72 * time.Hour) + assert.True(t, a.isFixedDailyPeriodExpired(periodStart)) +} + +// --------------------------------------------------------------------------- +// isFixedWeeklyPeriodExpired +// --------------------------------------------------------------------------- + +func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_weekly_reset_mode": "fixed", + "quota_weekly_reset_day": float64(1), + "quota_weekly_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{})) +} + +func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_weekly_reset_mode": "fixed", + "quota_weekly_reset_day": float64(1), + "quota_weekly_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + // Period started 1 minute ago → not expired + periodStart := time.Now().Add(-1 * time.Minute) + assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart)) +} + +func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_weekly_reset_mode": "fixed", + "quota_weekly_reset_day": float64(1), + "quota_weekly_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + // Period started 10 days ago → definitely expired + periodStart := time.Now().Add(-240 * time.Hour) + assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart)) +} + +// --------------------------------------------------------------------------- +// ValidateQuotaResetConfig +// --------------------------------------------------------------------------- + +func TestValidateQuotaResetConfig_NilExtra(t *testing.T) { + assert.NoError(t, ValidateQuotaResetConfig(nil)) +} + +func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) { + 
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{})) +} + +func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_weekly_reset_mode": "fixed", + "quota_weekly_reset_day": float64(1), + "quota_weekly_reset_hour": float64(0), + "quota_reset_timezone": "Asia/Shanghai", + } + assert.NoError(t, ValidateQuotaResetConfig(extra)) +} + +func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "rolling", + "quota_weekly_reset_mode": "rolling", + } + assert.NoError(t, ValidateQuotaResetConfig(extra)) +} + +func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) { + extra := map[string]any{ + "quota_reset_timezone": "Not/A/Timezone", + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_reset_timezone") +} + +func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "invalid", + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_daily_reset_mode") +} + +func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_hour": float64(24), + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_daily_reset_hour") +} + +func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_hour": float64(-1), + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_daily_reset_hour") +} + +func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) { + extra := map[string]any{ + "quota_weekly_reset_mode": "unknown", + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + 
assert.Contains(t, err.Error(), "quota_weekly_reset_mode") +} + +func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) { + extra := map[string]any{ + "quota_weekly_reset_day": float64(7), + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_weekly_reset_day") +} + +func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) { + extra := map[string]any{ + "quota_weekly_reset_day": float64(-1), + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_weekly_reset_day") +} + +func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) { + extra := map[string]any{ + "quota_weekly_reset_hour": float64(25), + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_weekly_reset_hour") +} + +func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) { + // All boundary values should be valid + extra := map[string]any{ + "quota_daily_reset_hour": float64(23), + "quota_weekly_reset_day": float64(0), // Sunday + "quota_weekly_reset_hour": float64(0), + "quota_reset_timezone": "UTC", + } + assert.NoError(t, ValidateQuotaResetConfig(extra)) + + extra2 := map[string]any{ + "quota_daily_reset_hour": float64(0), + "quota_weekly_reset_day": float64(6), // Saturday + "quota_weekly_reset_hour": float64(23), + } + assert.NoError(t, ValidateQuotaResetConfig(extra2)) +} + +// --------------------------------------------------------------------------- +// ComputeQuotaResetAt +// --------------------------------------------------------------------------- + +func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "rolling", + "quota_weekly_reset_mode": "rolling", + } + ComputeQuotaResetAt(extra) + _, hasDailyResetAt := extra["quota_daily_reset_at"] + _, hasWeeklyResetAt := extra["quota_weekly_reset_at"] + assert.False(t, 
hasDailyResetAt, "rolling mode should not set quota_daily_reset_at") + assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at") +} + +func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "rolling", + "quota_weekly_reset_mode": "rolling", + "quota_daily_reset_at": "2026-03-14T09:00:00Z", + "quota_weekly_reset_at": "2026-03-16T09:00:00Z", + } + ComputeQuotaResetAt(extra) + _, hasDailyResetAt := extra["quota_daily_reset_at"] + _, hasWeeklyResetAt := extra["quota_weekly_reset_at"] + assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at") + assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at") +} + +func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + } + ComputeQuotaResetAt(extra) + resetAtStr, ok := extra["quota_daily_reset_at"].(string) + require.True(t, ok, "quota_daily_reset_at should be set") + + resetAt, err := time.Parse(time.RFC3339, resetAtStr) + require.NoError(t, err) + // Reset time should be in the future + assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future") + // Reset hour should be 9 UTC + assert.Equal(t, 9, resetAt.UTC().Hour()) +} + +func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) { + extra := map[string]any{ + "quota_weekly_reset_mode": "fixed", + "quota_weekly_reset_day": float64(1), // Monday + "quota_weekly_reset_hour": float64(0), + "quota_reset_timezone": "UTC", + } + ComputeQuotaResetAt(extra) + resetAtStr, ok := extra["quota_weekly_reset_at"].(string) + require.True(t, ok, "quota_weekly_reset_at should be set") + + resetAt, err := time.Parse(time.RFC3339, resetAtStr) + require.NoError(t, err) + // Reset time should be in the future + assert.True(t, resetAt.After(time.Now()), 
"reset_at should be in the future") + // Reset day should be Monday + assert.Equal(t, time.Monday, resetAt.UTC().Weekday()) +} + +func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) { + tz, err := time.LoadLocation("Asia/Shanghai") + require.NoError(t, err) + + extra := map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "Asia/Shanghai", + } + ComputeQuotaResetAt(extra) + resetAtStr, ok := extra["quota_daily_reset_at"].(string) + require.True(t, ok) + + resetAt, err := time.Parse(time.RFC3339, resetAtStr) + require.NoError(t, err) + // In Shanghai timezone, the hour should be 9 + assert.Equal(t, 9, resetAt.In(tz).Hour()) +} + +func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(12), + } + ComputeQuotaResetAt(extra) + resetAtStr, ok := extra["quota_daily_reset_at"].(string) + require.True(t, ok) + + resetAt, err := time.Parse(time.RFC3339, resetAtStr) + require.NoError(t, err) + // Default timezone is UTC + assert.Equal(t, 12, resetAt.UTC().Hour()) +} + +func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(99), + "quota_reset_timezone": "UTC", + } + ComputeQuotaResetAt(extra) + resetAtStr, ok := extra["quota_daily_reset_at"].(string) + require.True(t, ok) + + resetAt, err := time.Parse(time.RFC3339, resetAtStr) + require.NoError(t, err) + // Invalid hour → clamped to 0 + assert.Equal(t, 0, resetAt.UTC().Hour()) +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index a3707184..a06d8048 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -54,6 +54,8 @@ type AccountRepository interface { ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID 
int64, platform string) ([]Account, error) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) + ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) + ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error @@ -66,6 +68,10 @@ type AccountRepository interface { UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error UpdateExtra(ctx context.Context, id int64, updates map[string]any) error BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) + // IncrementQuotaUsed 原子递增 API Key 账号的配额用量(总/日/周) + IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error + // ResetQuotaUsed 重置 API Key 账号所有维度的配额用量为 0 + ResetQuotaUsed(ctx context.Context, id int64) error } // AccountBulkUpdate describes the fields that can be updated in a bulk operation. 
@@ -76,6 +82,7 @@ type AccountBulkUpdate struct { Concurrency *int Priority *int RateMultiplier *float64 + LoadFactor *int Status *string Schedulable *bool Credentials map[string]any diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index a466b68a..c96b436f 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -147,6 +147,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte panic("unexpected ListSchedulableByGroupIDAndPlatforms call") } +func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + panic("unexpected ListSchedulableUngroupedByPlatform call") +} + +func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + panic("unexpected ListSchedulableUngroupedByPlatforms call") +} + func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { panic("unexpected SetRateLimited call") } @@ -191,6 +199,14 @@ func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates A panic("unexpected BulkUpdate call") } +func (s *accountRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func (s *accountRepoStub) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + // TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。 // 预期行为: // - ExistsByID 返回 false(账号不存在) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index c55e418d..482d22b1 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -12,6 +12,7 @@ import ( "io" "log" "net/http" + "net/http/httptest" "net/url" "regexp" "strings" @@ -33,7 +34,7 @@ import ( var 
sseDataPrefix = regexp.MustCompile(`^data:\s*`) const ( - testClaudeAPIURL = "https://api.anthropic.com/v1/messages" + testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions" @@ -44,16 +45,23 @@ const ( // TestEvent represents a SSE event for account testing type TestEvent struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Model string `json:"model,omitempty"` - Status string `json:"status,omitempty"` - Code string `json:"code,omitempty"` - Data any `json:"data,omitempty"` - Success bool `json:"success,omitempty"` - Error string `json:"error,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Model string `json:"model,omitempty"` + Status string `json:"status,omitempty"` + Code string `json:"code,omitempty"` + ImageURL string `json:"image_url,omitempty"` + MimeType string `json:"mime_type,omitempty"` + Data any `json:"data,omitempty"` + Success bool `json:"success,omitempty"` + Error string `json:"error,omitempty"` } +const ( + defaultGeminiTextTestPrompt = "hi" + defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background." 
+) + // AccountTestService handles account testing operations type AccountTestService struct { accountRepo AccountRepository @@ -160,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) { // TestAccountConnection tests an account's connection by sending a test request // All account types use full Claude Code client characteristics, only auth header differs // modelID is optional - if empty, defaults to claude.DefaultTestModel -func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error { +func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error { ctx := c.Request.Context() // Get account @@ -175,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int } if account.IsGemini() { - return s.testGeminiAccountConnection(c, account, modelID) + return s.testGeminiAccountConnection(c, account, modelID, prompt) } if account.Platform == PlatformAntigravity { - return s.testAntigravityAccountConnection(c, account, modelID) + return s.routeAntigravityTest(c, account, modelID, prompt) } if account.Platform == PlatformSora { @@ -199,14 +207,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account testModelID = claude.DefaultTestModel } - // For API Key accounts with model mapping, map the model + // API Key 账号测试连接时也需要应用通配符模型映射。 if account.Type == "apikey" { - mapping := account.GetModelMapping() - if len(mapping) > 0 { - if mappedModel, exists := mapping[testModelID]; exists { - testModelID = mappedModel - } - } + testModelID = account.GetMappedModel(testModelID) + } + + // Bedrock accounts use a separate test path + if account.IsBedrock() { + return s.testBedrockAccountConnection(c, ctx, account, testModelID) } // Determine authentication method and API URL @@ -238,7 +246,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account if err != nil { return 
s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) } - apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages" + apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages?beta=true" } else { return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) } @@ -304,6 +312,109 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account return s.processClaudeStream(c, resp.Body) } +// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke +func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error { + region := bedrockRuntimeRegion(account) + resolvedModelID, ok := ResolveBedrockModelID(account, testModelID) + if !ok { + return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Bedrock model: %s", testModelID)) + } + testModelID = resolvedModelID + + // Set SSE headers (test UI expects SSE) + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + // Create a minimal Bedrock-compatible payload (no stream, no cache_control) + bedrockPayload := map[string]any{ + "anthropic_version": "bedrock-2023-05-31", + "messages": []map[string]any{ + { + "role": "user", + "content": []map[string]any{ + { + "type": "text", + "text": "hi", + }, + }, + }, + }, + "max_tokens": 256, + "temperature": 1, + } + bedrockBody, _ := json.Marshal(bedrockPayload) + + // Use non-streaming endpoint (response is standard Claude JSON) + apiURL := BuildBedrockURL(region, testModelID, false) + + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bedrockBody)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create 
request") + } + req.Header.Set("Content-Type", "application/json") + + // Sign or set auth based on account type + if account.IsBedrockAPIKey() { + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return s.sendErrorAndEnd(c, "No API key available") + } + req.Header.Set("Authorization", "Bearer "+apiKey) + } else { + signer, err := NewBedrockSignerFromAccount(account) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Bedrock signer: %s", err.Error())) + } + if err := signer.SignRequest(ctx, req, bedrockBody); err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to sign request: %s", err.Error())) + } + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) + } + + // Bedrock non-streaming response is standard Claude JSON, extract the text + var result struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + } + if err := json.Unmarshal(body, &result); err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error())) + } + + text := "" + if len(result.Content) > 0 { + text = result.Content[0].Text + } + if text == "" { + text = "(empty response)" + } + + s.sendEvent(c, TestEvent{Type: "content", Text: text}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + // testOpenAIAccountConnection tests an OpenAI account's connection func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error { ctx := 
c.Request.Context() @@ -405,8 +516,27 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account } defer func() { _ = resp.Body.Close() }() + if isOAuth && s.accountRepo != nil { + if updates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(updates) > 0 { + _ = s.accountRepo.UpdateExtra(ctx, account.ID, updates) + mergeAccountExtra(account, updates) + } + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil { + _ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt) + account.RateLimitResetAt = resetAt + } + } + } + if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) + if isOAuth && s.accountRepo != nil { + if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil { + _ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt) + account.RateLimitResetAt = resetAt + } + } return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) } @@ -415,7 +545,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account } // testGeminiAccountConnection tests a Gemini account's connection -func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error { +func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error { ctx := c.Request.Context() // Determine the model to use @@ -442,7 +572,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account c.Writer.Flush() // Create test payload (Gemini format) - payload := createGeminiTestPayload() + payload := createGeminiTestPayload(testModelID, prompt) // Build request based on account type var req *http.Request @@ -1176,6 +1306,18 @@ func truncateSoraErrorBody(body []byte, max int) string { return soraerror.TruncateBody(body, max) } 
+// routeAntigravityTest 路由 Antigravity 账号的测试请求。 +// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。 +func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error { + if account.Type == AccountTypeAPIKey { + if strings.HasPrefix(modelID, "gemini-") { + return s.testGeminiAccountConnection(c, account, modelID, prompt) + } + return s.testClaudeAccountConnection(c, account, modelID) + } + return s.testAntigravityAccountConnection(c, account, modelID) +} + // testAntigravityAccountConnection tests an Antigravity account's connection // 支持 Claude 和 Gemini 两种协议,使用非流式请求 func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { @@ -1317,14 +1459,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT return req, nil } -// createGeminiTestPayload creates a minimal test payload for Gemini API -func createGeminiTestPayload() []byte { +// createGeminiTestPayload creates a minimal test payload for Gemini API. +// Image models use the image-generation path so the frontend can preview the returned image. 
+func createGeminiTestPayload(modelID string, prompt string) []byte { + if isImageGenerationModel(modelID) { + imagePrompt := strings.TrimSpace(prompt) + if imagePrompt == "" { + imagePrompt = defaultGeminiImageTestPrompt + } + + payload := map[string]any{ + "contents": []map[string]any{ + { + "role": "user", + "parts": []map[string]any{ + {"text": imagePrompt}, + }, + }, + }, + "generationConfig": map[string]any{ + "responseModalities": []string{"TEXT", "IMAGE"}, + "imageConfig": map[string]any{ + "aspectRatio": "1:1", + }, + }, + } + bytes, _ := json.Marshal(payload) + return bytes + } + + textPrompt := strings.TrimSpace(prompt) + if textPrompt == "" { + textPrompt = defaultGeminiTextTestPrompt + } + payload := map[string]any{ "contents": []map[string]any{ { "role": "user", "parts": []map[string]any{ - {"text": "hi"}, + {"text": textPrompt}, }, }, }, @@ -1384,6 +1558,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) if text, ok := partMap["text"].(string); ok && text != "" { s.sendEvent(c, TestEvent{Type: "content", Text: text}) } + if inlineData, ok := partMap["inlineData"].(map[string]any); ok { + mimeType, _ := inlineData["mimeType"].(string) + data, _ := inlineData["data"].(string) + if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" { + s.sendEvent(c, TestEvent{ + Type: "image", + ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data), + MimeType: mimeType, + }) + } + } } } } @@ -1560,3 +1745,62 @@ func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) er s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg}) return fmt.Errorf("%s", errorMsg) } + +// RunTestBackground executes an account test in-memory (no real HTTP client), +// capturing SSE output via httptest.NewRecorder, then parses the result. 
+func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID int64, modelID string) (*ScheduledTestResult, error) { + startedAt := time.Now() + + w := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(w) + ginCtx.Request = (&http.Request{}).WithContext(ctx) + + testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "") + + finishedAt := time.Now() + body := w.Body.String() + responseText, errMsg := parseTestSSEOutput(body) + + status := "success" + if testErr != nil || errMsg != "" { + status = "failed" + if errMsg == "" && testErr != nil { + errMsg = testErr.Error() + } + } + + return &ScheduledTestResult{ + Status: status, + ResponseText: responseText, + ErrorMessage: errMsg, + LatencyMs: finishedAt.Sub(startedAt).Milliseconds(), + StartedAt: startedAt, + FinishedAt: finishedAt, + }, nil +} + +// parseTestSSEOutput extracts response text and error message from captured SSE output. +func parseTestSSEOutput(body string) (responseText, errMsg string) { + var texts []string + for _, line := range strings.Split(body, "\n") { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data: ") { + continue + } + jsonStr := strings.TrimPrefix(line, "data: ") + var event TestEvent + if err := json.Unmarshal([]byte(jsonStr), &event); err != nil { + continue + } + switch event.Type { + case "content": + if event.Text != "" { + texts = append(texts, event.Text) + } + case "error": + errMsg = event.Error + } + } + responseText = strings.Join(texts, "") + return +} diff --git a/backend/internal/service/account_test_service_gemini_test.go b/backend/internal/service/account_test_service_gemini_test.go new file mode 100644 index 00000000..5ba04c69 --- /dev/null +++ b/backend/internal/service/account_test_service_gemini_test.go @@ -0,0 +1,59 @@ +//go:build unit + +package service + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func 
TestCreateGeminiTestPayload_ImageModel(t *testing.T) { + t.Parallel() + + payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot") + + var parsed struct { + Contents []struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + } `json:"contents"` + GenerationConfig struct { + ResponseModalities []string `json:"responseModalities"` + ImageConfig struct { + AspectRatio string `json:"aspectRatio"` + } `json:"imageConfig"` + } `json:"generationConfig"` + } + + require.NoError(t, json.Unmarshal(payload, &parsed)) + require.Len(t, parsed.Contents, 1) + require.Len(t, parsed.Contents[0].Parts, 1) + require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text) + require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities) + require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio) +} + +func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + ctx, recorder := newSoraTestContext() + svc := &AccountTestService{} + + stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n") + + err := svc.processGeminiStream(ctx, stream) + require.NoError(t, err) + + body := recorder.Body.String() + require.Contains(t, body, "\"type\":\"content\"") + require.Contains(t, body, "\"text\":\"ok\"") + require.Contains(t, body, "\"type\":\"image\"") + require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"") + require.Contains(t, body, "\"mime_type\":\"image/png\"") +} diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go new file mode 100644 index 00000000..efa6f7da --- /dev/null +++ b/backend/internal/service/account_test_service_openai_test.go @@ -0,0 +1,102 @@ +//go:build unit + +package service + +import ( + "context" + "io" + 
"net/http" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type openAIAccountTestRepo struct { + mockAccountRepoForGemini + updatedExtra map[string]any + rateLimitedID int64 + rateLimitedAt *time.Time +} + +func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + r.updatedExtra = updates + return nil +} + +func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error { + r.rateLimitedID = id + r.rateLimitedAt = &resetAt + return nil +} + +func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, recorder := newSoraTestContext() + + resp := newJSONResponse(http.StatusOK, "") + resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"} + +`)) + resp.Header.Set("x-codex-primary-used-percent", "88") + resp.Header.Set("x-codex-primary-reset-after-seconds", "604800") + resp.Header.Set("x-codex-primary-window-minutes", "10080") + resp.Header.Set("x-codex-secondary-used-percent", "42") + resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000") + resp.Header.Set("x-codex-secondary-window-minutes", "300") + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 89, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.NoError(t, err) + require.NotEmpty(t, repo.updatedExtra) + require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"]) + require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"]) + require.Contains(t, recorder.Body.String(), "test_complete") +} + +func 
TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newSoraTestContext() + + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`) + resp.Header.Set("x-codex-primary-used-percent", "100") + resp.Header.Set("x-codex-primary-reset-after-seconds", "604800") + resp.Header.Set("x-codex-primary-window-minutes", "10080") + resp.Header.Set("x-codex-secondary-used-percent", "100") + resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000") + resp.Header.Set("x-codex-secondary-window-minutes", "300") + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 88, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.Error(t, err) + require.NotEmpty(t, repo.updatedExtra) + require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"]) + require.Equal(t, int64(88), repo.rateLimitedID) + require.NotNil(t, repo.rateLimitedAt) + require.NotNil(t, account.RateLimitResetAt) + if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil { + require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second) + } +} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 6dee6c13..f117abfd 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -1,17 +1,25 @@ package service import ( + "bytes" "context" + "encoding/json" "fmt" "log" + "log/slog" + "math/rand/v2" + "net/http" "strings" "sync" "time" + httppool "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + openaipkg 
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "golang.org/x/sync/errgroup" + "golang.org/x/sync/singleflight" ) type UsageLogRepository interface { @@ -37,9 +45,12 @@ type UsageLogRepository interface { GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) + GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) + GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) + GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) 
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) @@ -70,8 +81,10 @@ type accountWindowStatsBatchReader interface { } // apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) +// 同时支持缓存错误响应(负缓存),防止 429 等错误导致的重试风暴 type apiUsageCache struct { response *ClaudeUsageResponse + err error // 非 nil 表示缓存的错误(负缓存) timestamp time.Time } @@ -88,15 +101,23 @@ type antigravityUsageCache struct { } const ( - apiCacheTTL = 3 * time.Minute - windowStatsCacheTTL = 1 * time.Minute + apiCacheTTL = 3 * time.Minute + apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟 + antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误) + apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟 + windowStatsCacheTTL = 1 * time.Minute + openAIProbeCacheTTL = 10 * time.Minute + openAICodexProbeVersion = "0.104.0" ) // UsageCache 封装账户使用量相关的缓存 type UsageCache struct { - apiCache sync.Map // accountID -> *apiUsageCache - windowStatsCache sync.Map // accountID -> *windowStatsCache - antigravityCache sync.Map // accountID -> *antigravityUsageCache + apiCache sync.Map // accountID -> *apiUsageCache + windowStatsCache sync.Map // accountID -> *windowStatsCache + antigravityCache sync.Map // accountID -> *antigravityUsageCache + apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic) + antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存 + openAIProbeCache sync.Map // accountID -> time.Time } // NewUsageCache 创建 UsageCache 实例 @@ -133,6 +154,25 @@ type AntigravityModelQuota struct { ResetTime string `json:"reset_time"` // 重置时间 ISO8601 } +// AntigravityModelDetail Antigravity 单个模型的详细能力信息 +type AntigravityModelDetail struct { + DisplayName string `json:"display_name,omitempty"` + SupportsImages *bool 
`json:"supports_images,omitempty"` + SupportsThinking *bool `json:"supports_thinking,omitempty"` + ThinkingBudget *int `json:"thinking_budget,omitempty"` + Recommended *bool `json:"recommended,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"` +} + +// AICredit 表示 Antigravity 账号的 AI Credits 余额信息。 +type AICredit struct { + CreditType string `json:"credit_type,omitempty"` + Amount float64 `json:"amount,omitempty"` + MinimumBalance float64 `json:"minimum_balance,omitempty"` +} + // UsageInfo 账号使用量信息 type UsageInfo struct { UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 @@ -148,6 +188,36 @@ type UsageInfo struct { // Antigravity 多模型配额 AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"` + + // Antigravity 账号级信息 + SubscriptionTier string `json:"subscription_tier,omitempty"` // 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN + SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // 上游原始订阅等级名称 + + // Antigravity 模型详细能力信息(与 antigravity_quota 同 key) + AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"` + + // Antigravity AI Credits 余额 + AICredits []AICredit `json:"ai_credits,omitempty"` + + // Antigravity 废弃模型转发规则 (old_model_id -> new_model_id) + ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"` + + // Antigravity 账号是否被上游禁止 (HTTP 403) + IsForbidden bool `json:"is_forbidden,omitempty"` + ForbiddenReason string `json:"forbidden_reason,omitempty"` + ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden" + ValidationURL string `json:"validation_url,omitempty"` // 验证/申诉链接 + + // 状态标记(从 ForbiddenType / HTTP 错误码推导) + NeedsVerify bool `json:"needs_verify,omitempty"` // 需要人工验证(forbidden_type=validation) + IsBanned bool `json:"is_banned,omitempty"` // 
账号被封(forbidden_type=violation) + NeedsReauth bool `json:"needs_reauth,omitempty"` // token 失效需重新授权(401) + + // 错误码(机器可读):forbidden / unauthenticated / rate_limited / network_error + ErrorCode string `json:"error_code,omitempty"` + + // 获取 usage 时的错误信息(降级返回,而非 500) + Error string `json:"error,omitempty"` } // ClaudeUsageResponse Anthropic API返回的usage结构 @@ -224,6 +294,14 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return nil, fmt.Errorf("get account failed: %w", err) } + if account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth { + usage, err := s.getOpenAIUsage(ctx, account) + if err == nil { + s.tryClearRecoverableAccountError(ctx, account) + } + return usage, err + } + if account.Platform == PlatformGemini { usage, err := s.getGeminiUsage(ctx, account) if err == nil { @@ -245,24 +323,65 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U if account.CanGetUsage() { var apiResp *ClaudeUsageResponse - // 1. 检查 API 缓存(10 分钟) + // 1. 检查缓存(成功响应 3 分钟 / 错误响应 1 分钟) if cached, ok := s.cache.apiCache.Load(accountID); ok { - if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { - apiResp = cache.response + if cache, ok := cached.(*apiUsageCache); ok { + age := time.Since(cache.timestamp) + if cache.err != nil && age < apiErrorCacheTTL { + // 负缓存命中:返回缓存的错误,避免重试风暴 + return nil, cache.err + } + if cache.response != nil && age < apiCacheTTL { + apiResp = cache.response + } } } - // 2. 如果没有缓存,从 API 获取 + // 2. 
如果没有有效缓存,通过 singleflight 从 API 获取(防止并发击穿) if apiResp == nil { - apiResp, err = s.fetchOAuthUsageRaw(ctx, account) - if err != nil { - return nil, err + // 随机延迟:打散多账号并发请求,避免同一时刻大量相同 TLS 指纹请求 + // 触发上游反滥用检测。延迟范围 0~800ms,仅在缓存未命中时生效。 + jitter := time.Duration(rand.Int64N(int64(apiQueryMaxJitter))) + select { + case <-time.After(jitter): + case <-ctx.Done(): + return nil, ctx.Err() } - // 缓存 API 响应 - s.cache.apiCache.Store(accountID, &apiUsageCache{ - response: apiResp, - timestamp: time.Now(), + + flightKey := fmt.Sprintf("usage:%d", accountID) + result, flightErr, _ := s.cache.apiFlight.Do(flightKey, func() (any, error) { + // 再次检查缓存(可能在等待 singleflight 期间被其他请求填充) + if cached, ok := s.cache.apiCache.Load(accountID); ok { + if cache, ok := cached.(*apiUsageCache); ok { + age := time.Since(cache.timestamp) + if cache.err != nil && age < apiErrorCacheTTL { + return nil, cache.err + } + if cache.response != nil && age < apiCacheTTL { + return cache.response, nil + } + } + } + resp, fetchErr := s.fetchOAuthUsageRaw(ctx, account) + if fetchErr != nil { + // 负缓存:缓存错误响应,防止后续请求重复触发 429 + s.cache.apiCache.Store(accountID, &apiUsageCache{ + err: fetchErr, + timestamp: time.Now(), + }) + return nil, fetchErr + } + // 缓存成功响应 + s.cache.apiCache.Store(accountID, &apiUsageCache{ + response: resp, + timestamp: time.Now(), + }) + return resp, nil }) + if flightErr != nil { + return nil, flightErr + } + apiResp, _ = result.(*ClaudeUsageResponse) } // 3. 
构建 UsageInfo(每次都重新计算 RemainingSeconds) @@ -288,6 +407,237 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return nil, fmt.Errorf("account type %s does not support usage query", account.Type) } +func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) { + now := time.Now() + usage := &UsageInfo{UpdatedAt: &now} + + if account == nil { + return usage, nil + } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now) + + if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil { + usage.FiveHour = progress + } + if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil { + usage.SevenDay = progress + } + + if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { + if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) { + mergeAccountExtra(account, updates) + if resetAt != nil { + account.RateLimitResetAt = resetAt + } + if usage.UpdatedAt == nil { + usage.UpdatedAt = &now + } + if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil { + usage.FiveHour = progress + } + if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil { + usage.SevenDay = progress + } + } + } + + if s.usageLogRepo == nil { + return usage, nil + } + + if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil { + windowStats := windowStatsFromAccountStats(stats) + if hasMeaningfulWindowStats(windowStats) { + if usage.FiveHour == nil { + usage.FiveHour = &UsageProgress{Utilization: 0} + } + usage.FiveHour.WindowStats = windowStats + } + } + + if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil { + windowStats := windowStatsFromAccountStats(stats) + 
if hasMeaningfulWindowStats(windowStats) { + if usage.SevenDay == nil { + usage.SevenDay = &UsageProgress{Utilization: 0} + } + usage.SevenDay.WindowStats = windowStats + } + } + + return usage, nil +} + +func shouldRefreshOpenAICodexSnapshot(account *Account, usage *UsageInfo, now time.Time) bool { + if account == nil { + return false + } + if usage == nil { + return true + } + if usage.FiveHour == nil || usage.SevenDay == nil { + return true + } + if account.IsRateLimited() { + return true + } + return isOpenAICodexSnapshotStale(account, now) +} + +func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool { + if account == nil || !account.IsOpenAIOAuth() || !account.IsOpenAIResponsesWebSocketV2Enabled() { + return false + } + if account.Extra == nil { + return true + } + raw, ok := account.Extra["codex_usage_updated_at"] + if !ok { + return true + } + ts, err := parseTime(fmt.Sprint(raw)) + if err != nil { + return true + } + return now.Sub(ts) >= openAIProbeCacheTTL +} + +func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool { + if s == nil || s.cache == nil || accountID <= 0 { + return true + } + if cached, ok := s.cache.openAIProbeCache.Load(accountID); ok { + if ts, ok := cached.(time.Time); ok && now.Sub(ts) < openAIProbeCacheTTL { + return false + } + } + s.cache.openAIProbeCache.Store(accountID, now) + return true +} + +func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) { + if account == nil || !account.IsOAuth() { + return nil, nil, nil + } + accessToken := account.GetOpenAIAccessToken() + if accessToken == "" { + return nil, nil, fmt.Errorf("no access token available") + } + modelID := openaipkg.DefaultTestModel + payload := createOpenAITestPayload(modelID, true) + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err) + } + + reqCtx, cancel := 
context.WithTimeout(ctx, 15*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes)) + if err != nil { + return nil, nil, fmt.Errorf("create openai probe request: %w", err) + } + req.Host = "chatgpt.com" + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("OpenAI-Beta", "responses=experimental") + req.Header.Set("Originator", "codex_cli_rs") + req.Header.Set("Version", openAICodexProbeVersion) + req.Header.Set("User-Agent", codexCLIUserAgent) + if s.identityCache != nil { + if fp, fpErr := s.identityCache.GetFingerprint(reqCtx, account.ID); fpErr == nil && fp != nil && strings.TrimSpace(fp.UserAgent) != "" { + req.Header.Set("User-Agent", strings.TrimSpace(fp.UserAgent)) + } + } + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + client, err := httppool.GetClient(httppool.Options{ + ProxyURL: proxyURL, + Timeout: 15 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + }) + if err != nil { + return nil, nil, fmt.Errorf("build openai probe client: %w", err) + } + resp, err := client.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp) + if err != nil { + return nil, nil, err + } + if len(updates) > 0 || resetAt != nil { + s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt) + return updates, resetAt, nil + } + return nil, nil, nil +} + +func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) { + if s == nil || 
s.accountRepo == nil || accountID <= 0 { + return + } + if len(updates) == 0 && resetAt == nil { + return + } + + go func() { + updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer updateCancel() + if len(updates) > 0 { + _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + } + if resetAt != nil { + _ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt) + } + }() +} + +func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) { + if resp == nil { + return nil, nil, nil + } + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + baseTime := time.Now() + updates := buildCodexUsageExtraUpdates(snapshot, baseTime) + resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime) + if len(updates) > 0 { + return updates, resetAt, nil + } + return nil, resetAt, nil + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) + } + return nil, nil, nil +} + +func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) { + updates, _, err := extractOpenAICodexProbeSnapshot(resp) + return updates, err +} + +func mergeAccountExtra(account *Account, updates map[string]any) { + if account == nil || len(updates) == 0 { + return + } + if account.Extra == nil { + account.Extra = make(map[string]any, len(updates)) + } + for k, v := range updates { + account.Extra[k] = v + } +} + func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) { now := time.Now() usage := &UsageInfo{ @@ -352,34 +702,157 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account * return &UsageInfo{UpdatedAt: &now}, nil } - // 1. 检查缓存(10 分钟) + // 1. 
检查缓存 if cached, ok := s.cache.antigravityCache.Load(account.ID); ok { - if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { - // 重新计算 RemainingSeconds - usage := cache.usageInfo - if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil { - usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds()) + if cache, ok := cached.(*antigravityUsageCache); ok { + ttl := antigravityCacheTTL(cache.usageInfo) + if time.Since(cache.timestamp) < ttl { + usage := cache.usageInfo + if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil { + usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds()) + } + return usage, nil } - return usage, nil } } - // 2. 获取代理 URL - proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account) + // 2. singleflight 防止并发击穿 + flightKey := fmt.Sprintf("ag-usage:%d", account.ID) + result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) { + // 再次检查缓存(等待期间可能已被填充) + if cached, ok := s.cache.antigravityCache.Load(account.ID); ok { + if cache, ok := cached.(*antigravityUsageCache); ok { + ttl := antigravityCacheTTL(cache.usageInfo) + if time.Since(cache.timestamp) < ttl { + usage := cache.usageInfo + // 重新计算 RemainingSeconds,避免返回过时的剩余秒数 + recalcAntigravityRemainingSeconds(usage) + return usage, nil + } + } + } - // 3. 调用 API 获取额度 - result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL) - if err != nil { - return nil, fmt.Errorf("fetch antigravity quota failed: %w", err) - } + // 使用独立 context,避免调用方 cancel 导致所有共享 flight 的请求失败 + fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer fetchCancel() - // 4. 
缓存结果 - s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ - usageInfo: result.UsageInfo, - timestamp: time.Now(), + proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account) + fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL) + if err != nil { + degraded := buildAntigravityDegradedUsage(err) + enrichUsageWithAccountError(degraded, account) + s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ + usageInfo: degraded, + timestamp: time.Now(), + }) + return degraded, nil + } + + enrichUsageWithAccountError(fetchResult.UsageInfo, account) + s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ + usageInfo: fetchResult.UsageInfo, + timestamp: time.Now(), + }) + return fetchResult.UsageInfo, nil }) - return result.UsageInfo, nil + if flightErr != nil { + return nil, flightErr + } + usage, ok := result.(*UsageInfo) + if !ok || usage == nil { + now := time.Now() + return &UsageInfo{UpdatedAt: &now}, nil + } + return usage, nil +} + +// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds +// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数 +func recalcAntigravityRemainingSeconds(info *UsageInfo) { + if info == nil { + return + } + if info.FiveHour != nil && info.FiveHour.ResetsAt != nil { + remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds()) + if remaining < 0 { + remaining = 0 + } + info.FiveHour.RemainingSeconds = remaining + } +} + +// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL +// 403 forbidden 状态稳定,缓存与成功相同(3 分钟); +// 其他错误(401/网络)可能快速恢复,缓存 1 分钟。 +func antigravityCacheTTL(info *UsageInfo) time.Duration { + if info == nil { + return antigravityErrorTTL + } + if info.IsForbidden { + return apiCacheTTL // 封号/验证状态不会很快变 + } + if info.ErrorCode != "" || info.Error != "" { + return antigravityErrorTTL + } + return apiCacheTTL +} + +// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo +func buildAntigravityDegradedUsage(err error) *UsageInfo { + now := 
time.Now() + errMsg := fmt.Sprintf("usage API error: %v", err) + slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err) + + info := &UsageInfo{ + UpdatedAt: &now, + Error: errMsg, + } + + // 从错误信息推断 error_code 和状态标记 + // 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..." + errStr := err.Error() + switch { + case strings.Contains(errStr, "HTTP 401") || + strings.Contains(errStr, "UNAUTHENTICATED") || + strings.Contains(errStr, "invalid_grant"): + info.ErrorCode = errorCodeUnauthenticated + info.NeedsReauth = true + case strings.Contains(errStr, "HTTP 429"): + info.ErrorCode = errorCodeRateLimited + default: + info.ErrorCode = errorCodeNetworkError + } + + return info +} + +// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo +// 场景 1(成功路径):FetchAvailableModels 正常返回,但账号已因 403 被标记为 error, +// +// 需要在正常 usage 数据上附加 forbidden/validation 信息。 +// +// 场景 2(降级路径):被封号的账号 OAuth token 失效,FetchAvailableModels 返回 401, +// +// 降级逻辑设置了 needs_reauth,但账号实际是 403 封号/需验证,需覆盖为正确状态。 +func enrichUsageWithAccountError(info *UsageInfo, account *Account) { + if info == nil || account == nil || account.Status != StatusError { + return + } + msg := strings.ToLower(account.ErrorMessage) + if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") && + !strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") { + return + } + fbType := classifyForbiddenType(account.ErrorMessage) + info.IsForbidden = true + info.ForbiddenType = fbType + info.ForbiddenReason = account.ErrorMessage + info.NeedsVerify = fbType == forbiddenTypeValidation + info.IsBanned = fbType == forbiddenTypeViolation + info.ValidationURL = extractValidationURL(account.ErrorMessage) + info.ErrorCode = errorCodeForbidden + info.NeedsReauth = false } // addWindowStats 为 usage 数据添加窗口期统计 @@ -519,6 +992,72 @@ func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats { } } +func hasMeaningfulWindowStats(stats *WindowStats) bool { + if 
stats == nil { + return false + } + return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0 +} + +func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress { + if len(extra) == 0 { + return nil + } + + var ( + usedPercentKey string + resetAfterKey string + resetAtKey string + ) + + switch window { + case "5h": + usedPercentKey = "codex_5h_used_percent" + resetAfterKey = "codex_5h_reset_after_seconds" + resetAtKey = "codex_5h_reset_at" + case "7d": + usedPercentKey = "codex_7d_used_percent" + resetAfterKey = "codex_7d_reset_after_seconds" + resetAtKey = "codex_7d_reset_at" + default: + return nil + } + + usedRaw, ok := extra[usedPercentKey] + if !ok { + return nil + } + + progress := &UsageProgress{Utilization: parseExtraFloat64(usedRaw)} + if resetAtRaw, ok := extra[resetAtKey]; ok { + if resetAt, err := parseTime(fmt.Sprint(resetAtRaw)); err == nil { + progress.ResetsAt = &resetAt + progress.RemainingSeconds = int(time.Until(resetAt).Seconds()) + if progress.RemainingSeconds < 0 { + progress.RemainingSeconds = 0 + } + } + } + if progress.ResetsAt == nil { + if resetAfterSeconds := parseExtraInt(extra[resetAfterKey]); resetAfterSeconds > 0 { + base := now + if updatedAtRaw, ok := extra["codex_usage_updated_at"]; ok { + if updatedAt, err := parseTime(fmt.Sprint(updatedAtRaw)); err == nil { + base = updatedAt + } + } + resetAt := base.Add(time.Duration(resetAfterSeconds) * time.Second) + progress.ResetsAt = &resetAt + progress.RemainingSeconds = int(time.Until(resetAt).Seconds()) + if progress.RemainingSeconds < 0 { + progress.RemainingSeconds = 0 + } + } + } + + return progress +} + func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime) if err != nil { @@ -666,15 
+1205,30 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn remaining = 0 } - // 根据状态估算使用率 (百分比形式,100 = 100%) + // 优先使用响应头中存储的真实 utilization 值(0-1 小数,转为 0-100 百分比) var utilization float64 - switch account.SessionWindowStatus { - case "rejected": - utilization = 100.0 - case "allowed_warning": - utilization = 80.0 - default: - utilization = 0.0 + var found bool + if stored, ok := account.Extra["session_window_utilization"]; ok { + switch v := stored.(type) { + case float64: + utilization = v * 100 + found = true + case json.Number: + if f, err := v.Float64(); err == nil { + utilization = f * 100 + found = true + } + } + } + + // 如果没有存储的 utilization,回退到状态估算 + if !found { + switch account.SessionWindowStatus { + case "rejected": + utilization = 100.0 + case "allowed_warning": + utilization = 80.0 + } } info.FiveHour = &UsageProgress{ diff --git a/backend/internal/service/account_usage_service_test.go b/backend/internal/service/account_usage_service_test.go new file mode 100644 index 00000000..a063fe26 --- /dev/null +++ b/backend/internal/service/account_usage_service_test.go @@ -0,0 +1,150 @@ +package service + +import ( + "context" + "net/http" + "testing" + "time" +) + +type accountUsageCodexProbeRepo struct { + stubOpenAIAccountRepo + updateExtraCh chan map[string]any + rateLimitCh chan time.Time +} + +func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + if r.updateExtraCh != nil { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtraCh <- copied + } + return nil +} + +func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + +func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) { + t.Parallel() + + rateLimitedUntil := time.Now().Add(5 * time.Minute) + now := time.Now() + usage := 
&UsageInfo{ + FiveHour: &UsageProgress{Utilization: 0}, + SevenDay: &UsageProgress{Utilization: 0}, + } + + if !shouldRefreshOpenAICodexSnapshot(&Account{RateLimitResetAt: &rateLimitedUntil}, usage, now) { + t.Fatal("expected rate-limited account to force codex snapshot refresh") + } + + if shouldRefreshOpenAICodexSnapshot(&Account{}, usage, now) { + t.Fatal("expected complete non-rate-limited usage to skip codex snapshot refresh") + } + + if !shouldRefreshOpenAICodexSnapshot(&Account{}, &UsageInfo{FiveHour: nil, SevenDay: &UsageProgress{}}, now) { + t.Fatal("expected missing 5h snapshot to require refresh") + } + + staleAt := now.Add(-(openAIProbeCacheTTL + time.Minute)).Format(time.RFC3339) + if !shouldRefreshOpenAICodexSnapshot(&Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "codex_usage_updated_at": staleAt, + }, + }, usage, now) { + t.Fatal("expected stale ws snapshot to trigger refresh") + } +} + +func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) { + t.Parallel() + + headers := make(http.Header) + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + updates, err := extractOpenAICodexProbeUpdates(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers}) + if err != nil { + t.Fatalf("extractOpenAICodexProbeUpdates() error = %v", err) + } + if len(updates) == 0 { + t.Fatal("expected codex probe updates from 429 headers") + } + if got := updates["codex_5h_used_percent"]; got != 100.0 { + t.Fatalf("codex_5h_used_percent = %v, want 100", got) + } + if got := updates["codex_7d_used_percent"]; got != 100.0 { + 
t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } +} + +func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) { + t.Parallel() + + headers := make(http.Header) + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers}) + if err != nil { + t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err) + } + if len(updates) == 0 { + t.Fatal("expected codex probe updates from 429 headers") + } + if resetAt == nil { + t.Fatal("expected resetAt from exhausted codex headers") + } +} + +func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) { + t.Parallel() + + repo := &accountUsageCodexProbeRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), + } + svc := &AccountUsageService{accountRepo: repo} + resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second) + + svc.persistOpenAICodexProbeSnapshot(321, map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.Format(time.RFC3339), + }, &resetAt) + + select { + case updates := <-repo.updateExtraCh: + if got := updates["codex_7d_used_percent"]; got != 100.0 { + t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } + case <-time.After(2 * time.Second): + t.Fatal("waiting for codex probe extra persistence timed out") + } + + select { + case got := <-repo.rateLimitCh: + if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) { + t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt) + } + case <-time.After(2 * time.Second): + t.Fatal("waiting 
for codex probe rate limit persistence timed out") + } +} diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 7782f948..0d7ffffa 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) { } } -func TestMatchWildcardMapping(t *testing.T) { +func TestMatchWildcardMappingResult(t *testing.T) { tests := []struct { name string mapping map[string]string requestedModel string expected string + matched bool }{ // 精确匹配优先于通配符 { @@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-5-exact", + matched: true, }, // 最长通配符优先 @@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-series", + matched: true, }, // 单个通配符 @@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "claude-opus-4-5", expected: "claude-mapped", + matched: true, }, // 无匹配返回原始模型 @@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "gemini-3-flash", expected: "gemini-3-flash", + matched: false, }, // 空映射返回原始模型 @@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) { mapping: map[string]string{}, requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-5", + matched: false, }, // Gemini 模型映射 @@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "gemini-3-flash-preview", expected: "gemini-3-pro-high", + matched: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := matchWildcardMapping(tt.mapping, tt.requestedModel) - if result != tt.expected { - t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected) + result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel) + if result != 
tt.expected || matched != tt.matched { + t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched) } }) } @@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) { } } +func TestAccountResolveMappedModel(t *testing.T) { + tests := []struct { + name string + credentials map[string]any + requestedModel string + expectedModel string + expectedMatch bool + }{ + { + name: "no mapping reports unmatched", + credentials: nil, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: false, + }, + { + name: "exact passthrough mapping still counts as matched", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: true, + }, + { + name: "wildcard passthrough mapping still counts as matched", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-*": "gpt-5.4", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: true, + }, + { + name: "missing mapping reports unmatched", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.2": "gpt-5.2", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Credentials: tt.credentials, + } + mappedModel, matched := account.ResolveMappedModel(tt.requestedModel) + if mappedModel != tt.expectedModel || matched != tt.expectedMatch { + t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch) + } + }) + } +} + func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) { account := &Account{ Platform: PlatformAntigravity, diff --git a/backend/internal/service/admin_service.go 
b/backend/internal/service/admin_service.go index bdd1aa4a..ea76e171 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -42,6 +42,9 @@ type AdminService interface { UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) DeleteGroup(ctx context.Context, id int64) error GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) + GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) + ClearGroupRateMultipliers(ctx context.Context, groupID int64) error + BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error // API Key management (admin) @@ -57,6 +60,8 @@ type AdminService interface { RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) ClearAccountError(ctx context.Context, id int64) (*Account, error) SetAccountError(ctx context.Context, id int64, errorMsg string) error + // EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号 privacy_mode,未设置则尝试关闭训练数据共享并持久化。 + EnsureOpenAIPrivacy(ctx context.Context, account *Account) string SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error @@ -84,6 +89,7 @@ type AdminService interface { DeleteRedeemCode(ctx context.Context, id int64) error BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) + ResetAccountQuota(ctx context.Context, id int64) error } // CreateUserInput represents input for creating a new user via admin operations. 
@@ -144,6 +150,9 @@ type CreateGroupInput struct { SupportedModelScopes []string // Sora 存储配额 SoraStorageQuotaBytes int64 + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch bool + DefaultMappedModel string // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -180,6 +189,9 @@ type UpdateGroupInput struct { SupportedModelScopes *[]string // Sora 存储配额 SoraStorageQuotaBytes *int64 + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch *bool + DefaultMappedModel *string // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -195,6 +207,7 @@ type CreateAccountInput struct { Concurrency int Priority int RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int GroupIDs []int64 ExpiresAt *int64 AutoPauseOnExpired *bool @@ -215,6 +228,7 @@ type UpdateAccountInput struct { Concurrency *int // 使用指针区分"未提供"和"设置为0" Priority *int // 使用指针区分"未提供"和"设置为0" RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int Status string GroupIDs *[]int64 ExpiresAt *int64 @@ -230,6 +244,7 @@ type BulkUpdateAccountsInput struct { Concurrency *int Priority *int RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int Status string Schedulable *bool GroupIDs *[]int64 @@ -353,6 +368,10 @@ type ProxyExitInfoProber interface { ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) } +type groupExistenceBatchReader interface { + ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) +} + type proxyQualityTarget struct { Target string URL string @@ -422,16 +441,14 @@ type adminServiceImpl struct { entClient *dbent.Client // 用于开启数据库事务 settingService *SettingService defaultSubAssigner DefaultSubscriptionAssigner + userSubRepo UserSubscriptionRepository + privacyClientFactory PrivacyClientFactory } type userGroupRateBatchReader interface { GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) } -type groupExistenceBatchReader interface { - ExistsByIDs(ctx 
context.Context, ids []int64) (map[int64]bool, error) -} - // NewAdminService creates a new AdminService func NewAdminService( userRepo UserRepository, @@ -449,6 +466,8 @@ func NewAdminService( entClient *dbent.Client, settingService *SettingService, defaultSubAssigner DefaultSubscriptionAssigner, + userSubRepo UserSubscriptionRepository, + privacyClientFactory PrivacyClientFactory, ) AdminService { return &adminServiceImpl{ userRepo: userRepo, @@ -466,6 +485,8 @@ func NewAdminService( entClient: entClient, settingService: settingService, defaultSubAssigner: defaultSubAssigner, + userSubRepo: userSubRepo, + privacyClientFactory: privacyClientFactory, } } @@ -745,7 +766,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) + keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{}) if err != nil { return nil, 0, err } @@ -811,7 +832,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn subscriptionType = SubscriptionTypeStandard } - // 限额字段:0 和 nil 都表示"无限制" + // 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额 dailyLimit := normalizeLimit(input.DailyLimitUSD) weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) @@ -905,6 +926,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn MCPXMLInject: mcpXMLInject, SupportedModelScopes: input.SupportedModelScopes, SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, + AllowMessagesDispatch: input.AllowMessagesDispatch, + DefaultMappedModel: input.DefaultMappedModel, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -921,9 +944,9 @@ func (s *adminServiceImpl) CreateGroup(ctx 
context.Context, input *CreateGroupIn return group, nil } -// normalizeLimit 将 0 或负数转换为 nil(表示无限制) +// normalizeLimit 将负数转换为 nil(表示无限制),0 保留(表示限额为零) func normalizeLimit(limit *float64) *float64 { - if limit == nil || *limit <= 0 { + if limit == nil || *limit < 0 { return nil } return limit @@ -1035,16 +1058,11 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.SubscriptionType != "" { group.SubscriptionType = input.SubscriptionType } - // 限额字段:0 和 nil 都表示"无限制",正数表示具体限额 - if input.DailyLimitUSD != nil { - group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) - } - if input.WeeklyLimitUSD != nil { - group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) - } - if input.MonthlyLimitUSD != nil { - group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) - } + // 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额 + // 前端始终发送这三个字段,无需 nil 守卫 + group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) + group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) + group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) // 图片生成计费配置:负数表示清除(使用默认价格) if input.ImagePrice1K != nil { group.ImagePrice1K = normalizePrice(input.ImagePrice1K) @@ -1118,6 +1136,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.SupportedModelScopes = *input.SupportedModelScopes } + // OpenAI Messages 调度配置 + if input.AllowMessagesDispatch != nil { + group.AllowMessagesDispatch = *input.AllowMessagesDispatch + } + if input.DefaultMappedModel != nil { + group.DefaultMappedModel = *input.DefaultMappedModel + } + if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } @@ -1221,6 +1247,27 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p return keys, result.Total, nil } +func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) { + if s.userGroupRateRepo == nil { + return nil, nil + } + return 
s.userGroupRateRepo.GetByGroupID(ctx, groupID) +} + +func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error { + if s.userGroupRateRepo == nil { + return nil + } + return s.userGroupRateRepo.DeleteByGroupID(ctx, groupID) +} + +func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error { + if s.userGroupRateRepo == nil { + return nil + } + return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) +} + func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { return s.groupRepo.UpdateSortOrders(ctx, updates) } @@ -1257,9 +1304,17 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i if group.Status != StatusActive { return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") } - // 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程 + // 订阅类型分组:用户须持有该分组的有效订阅才可绑定 if group.IsSubscriptionType() { - return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow") + if s.userSubRepo == nil { + return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured") + } + if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil { + if errors.Is(err, ErrSubscriptionNotFound) { + return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group") + } + return nil, err + } } gid := *groupID @@ -1267,7 +1322,7 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i apiKey.Group = group // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性 - if group.IsExclusive { + if group.IsExclusive && !group.IsSubscriptionType() { opCtx := ctx var tx *dbent.Tx if s.entClient == nil { @@ -1329,6 +1384,10 @@ func (s *adminServiceImpl) 
ListAccounts(ctx context.Context, page, pageSize int, if err != nil { return nil, 0, err } + now := time.Now() + for i := range accounts { + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now) + } return accounts, result.Total, nil } @@ -1398,6 +1457,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou Status: StatusActive, Schedulable: true, } + // 预计算固定时间重置的下次重置时间 + if account.Extra != nil { + if err := ValidateQuotaResetConfig(account.Extra); err != nil { + return nil, err + } + ComputeQuotaResetAt(account.Extra) + } if input.ExpiresAt != nil && *input.ExpiresAt > 0 { expiresAt := time.Unix(*input.ExpiresAt, 0) account.ExpiresAt = &expiresAt @@ -1413,6 +1479,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } account.RateMultiplier = input.RateMultiplier } + if input.LoadFactor != nil && *input.LoadFactor > 0 { + if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } + account.LoadFactor = input.LoadFactor + } if err := s.accountRepo.Create(ctx, account); err != nil { return nil, err } @@ -1444,6 +1516,7 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U if err != nil { return nil, err } + wasOveragesEnabled := account.IsOveragesEnabled() if input.Name != "" { account.Name = input.Name @@ -1458,7 +1531,29 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.Credentials = input.Credentials } if len(input.Extra) > 0 { + // 保留配额用量字段,防止编辑账号时意外重置 + for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} { + if v, ok := account.Extra[key]; ok { + input.Extra[key] = v + } + } account.Extra = input.Extra + if account.Platform == PlatformAntigravity && wasOveragesEnabled && !account.IsOveragesEnabled() { + delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态 + // 清除 AICredits 限流 key 
+ if rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any); ok { + delete(rawLimits, creditsExhaustedKey) + } + } + if account.Platform == PlatformAntigravity && !wasOveragesEnabled && account.IsOveragesEnabled() { + delete(account.Extra, modelRateLimitsKey) + delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态 + } + // 校验并预计算固定时间重置的下次重置时间 + if err := ValidateQuotaResetConfig(account.Extra); err != nil { + return nil, err + } + ComputeQuotaResetAt(account.Extra) } if input.ProxyID != nil { // 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图) @@ -1483,6 +1578,15 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U } account.RateMultiplier = input.RateMultiplier } + if input.LoadFactor != nil { + if *input.LoadFactor <= 0 { + account.LoadFactor = nil // 0 或负数表示清除 + } else if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } else { + account.LoadFactor = input.LoadFactor + } + } if input.Status != "" { account.Status = input.Status } @@ -1616,6 +1720,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if input.RateMultiplier != nil { repoUpdates.RateMultiplier = input.RateMultiplier } + if input.LoadFactor != nil { + if *input.LoadFactor <= 0 { + repoUpdates.LoadFactor = nil // 0 或负数表示清除 + } else if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } else { + repoUpdates.LoadFactor = input.LoadFactor + } + } if input.Status != "" { repoUpdates.Status = &input.Status } @@ -1669,16 +1782,10 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int } func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) { - account, err := s.accountRepo.GetByID(ctx, id) - if err != nil { + if err := s.accountRepo.ClearError(ctx, id); err != nil { return nil, err } - account.Status = StatusActive - account.ErrorMessage = "" - if err := s.accountRepo.Update(ctx, 
account); err != nil { - return nil, err - } - return account, nil + return s.accountRepo.GetByID(ctx, id) } func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error { @@ -2028,7 +2135,6 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr ProxyURL: proxyURL, Timeout: proxyQualityRequestTimeout, ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, - ProxyStrict: true, }) if err != nil { result.Items = append(result.Items, ProxyQualityCheckItem{ @@ -2440,3 +2546,43 @@ func (e *MixedChannelError) Error() string { return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.", e.GroupName, e.CurrentPlatform, e.OtherPlatform) } + +func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error { + return s.accountRepo.ResetQuotaUsed(ctx, id) +} + +// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号是否已设置 privacy_mode, +// 未设置则调用 disableOpenAITraining 并持久化到 Extra,返回设置的 mode 值。 +func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Account) string { + if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { + return "" + } + if s.privacyClientFactory == nil { + return "" + } + if account.Extra != nil { + if _, ok := account.Extra["privacy_mode"]; ok { + return "" + } + } + + token, _ := account.Credentials["access_token"].(string) + if token == "" { + return "" + } + + var proxyURL string + if account.ProxyID != nil { + if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil { + proxyURL = p.URL() + } + } + + mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL) + if mode == "" { + return "" + } + + _ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}) + return mode +} diff --git 
a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 9210a786..88d2f492 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -32,28 +32,44 @@ func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, return s.addGroupErr } -func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected") } func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { panic("unexpected") } -func 
(s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } // apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests. 
type apiKeyRepoStubForGroupUpdate struct { @@ -91,7 +107,7 @@ func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string) panic("unexpected") } func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } -func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { +func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { panic("unexpected") } func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { @@ -127,6 +143,15 @@ func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64 func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error { panic("unexpected") } +func (s *apiKeyRepoStubForGroupUpdate) IncrementRateLimitUsage(context.Context, int64, float64) error { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, int64) error { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) { + panic("unexpected") +} // groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests. 
type groupRepoStubForGroupUpdate struct { @@ -185,6 +210,29 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS panic("unexpected") } +type userSubRepoStubForGroupUpdate struct { + userSubRepoNoop + getActiveSub *UserSubscription + getActiveErr error + called bool + calledUserID int64 + calledGroupID int64 +} + +func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) { + s.called = true + s.calledUserID = userID + s.calledGroupID = groupID + if s.getActiveErr != nil { + return nil, s.getActiveErr + } + if s.getActiveSub == nil { + return nil, ErrSubscriptionNotFound + } + clone := *s.getActiveSub + return &clone, nil +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -377,14 +425,49 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) { existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} - groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}} + userRepo := &userRepoStubForGroupUpdate{} + userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo} + + // 无有效订阅时应拒绝绑定 + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.Error(t, err) + require.Equal(t, "SUBSCRIPTION_REQUIRED", 
infraerrors.Reason(err)) + require.True(t, userSubRepo.called) + require.Equal(t, int64(42), userSubRepo.calledUserID) + require.Equal(t, int64(10), userSubRepo.calledGroupID) + require.False(t, userRepo.addGroupCalled) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}} userRepo := &userRepoStubForGroupUpdate{} svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} - // 订阅类型分组应被阻止绑定 _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) require.Error(t, err) - require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err)) + require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err)) + require.False(t, userRepo.addGroupCalled) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} + userRepo := &userRepoStubForGroupUpdate{} + userSubRepo := &userSubRepoStubForGroupUpdate{ + getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10}, + } + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.NoError(t, err) + require.True(t, userSubRepo.called) + require.NotNil(t, got.APIKey.GroupID) + require.Equal(t, int64(10), 
*got.APIKey.GroupID) require.False(t, userRepo.addGroupCalled) } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index bb906df5..2e0f7d90 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -348,6 +348,19 @@ func (s *billingCacheStub) InvalidateSubscriptionCache(ctx context.Context, user return nil } +func (s *billingCacheStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) { + panic("unexpected GetAPIKeyRateLimit call") +} +func (s *billingCacheStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error { + panic("unexpected SetAPIKeyRateLimit call") +} +func (s *billingCacheStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { + panic("unexpected UpdateAPIKeyRateLimitUsage call") +} +func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + panic("unexpected InvalidateAPIKeyRateLimit call") +} + func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall { t.Helper() calls := make([]subscriptionInvalidateCall, 0, expected) diff --git a/backend/internal/service/admin_service_group_rate_test.go b/backend/internal/service/admin_service_group_rate_test.go new file mode 100644 index 00000000..77635247 --- /dev/null +++ b/backend/internal/service/admin_service_group_rate_test.go @@ -0,0 +1,176 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// userGroupRateRepoStubForGroupRate implements UserGroupRateRepository for group rate tests. 
+type userGroupRateRepoStubForGroupRate struct { + getByGroupIDData map[int64][]UserGroupRateEntry + getByGroupIDErr error + + deletedGroupIDs []int64 + deleteByGroupErr error + + syncedGroupID int64 + syncedEntries []GroupRateMultiplierInput + syncGroupErr error +} + +func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) { + panic("unexpected GetByUserID call") +} + +func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, _, _ int64) (*float64, error) { + panic("unexpected GetByUserAndGroup call") +} + +func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) { + if s.getByGroupIDErr != nil { + return nil, s.getByGroupIDErr + } + return s.getByGroupIDData[groupID], nil +} + +func (s *userGroupRateRepoStubForGroupRate) SyncUserGroupRates(_ context.Context, _ int64, _ map[int64]*float64) error { + panic("unexpected SyncUserGroupRates call") +} + +func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.Context, groupID int64, entries []GroupRateMultiplierInput) error { + s.syncedGroupID = groupID + s.syncedEntries = entries + return s.syncGroupErr +} + +func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error { + s.deletedGroupIDs = append(s.deletedGroupIDs, groupID) + return s.deleteByGroupErr +} + +func (s *userGroupRateRepoStubForGroupRate) DeleteByUserID(_ context.Context, _ int64) error { + panic("unexpected DeleteByUserID call") +} + +func TestAdminService_GetGroupRateMultipliers(t *testing.T) { + t.Run("returns entries for group", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{ + getByGroupIDData: map[int64][]UserGroupRateEntry{ + 10: { + {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5}, + {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8}, + }, + }, + } + svc := 
&adminServiceImpl{userGroupRateRepo: repo} + + entries, err := svc.GetGroupRateMultipliers(context.Background(), 10) + require.NoError(t, err) + require.Len(t, entries, 2) + require.Equal(t, int64(1), entries[0].UserID) + require.Equal(t, "alice", entries[0].UserName) + require.Equal(t, 1.5, entries[0].RateMultiplier) + require.Equal(t, int64(2), entries[1].UserID) + require.Equal(t, 0.8, entries[1].RateMultiplier) + }) + + t.Run("returns nil when repo is nil", func(t *testing.T) { + svc := &adminServiceImpl{userGroupRateRepo: nil} + + entries, err := svc.GetGroupRateMultipliers(context.Background(), 10) + require.NoError(t, err) + require.Nil(t, entries) + }) + + t.Run("returns empty slice for group with no entries", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{ + getByGroupIDData: map[int64][]UserGroupRateEntry{}, + } + svc := &adminServiceImpl{userGroupRateRepo: repo} + + entries, err := svc.GetGroupRateMultipliers(context.Background(), 99) + require.NoError(t, err) + require.Nil(t, entries) + }) + + t.Run("propagates repo error", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{ + getByGroupIDErr: errors.New("db error"), + } + svc := &adminServiceImpl{userGroupRateRepo: repo} + + _, err := svc.GetGroupRateMultipliers(context.Background(), 10) + require.Error(t, err) + require.Contains(t, err.Error(), "db error") + }) +} + +func TestAdminService_ClearGroupRateMultipliers(t *testing.T) { + t.Run("deletes by group ID", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{} + svc := &adminServiceImpl{userGroupRateRepo: repo} + + err := svc.ClearGroupRateMultipliers(context.Background(), 42) + require.NoError(t, err) + require.Equal(t, []int64{42}, repo.deletedGroupIDs) + }) + + t.Run("returns nil when repo is nil", func(t *testing.T) { + svc := &adminServiceImpl{userGroupRateRepo: nil} + + err := svc.ClearGroupRateMultipliers(context.Background(), 42) + require.NoError(t, err) + }) + + t.Run("propagates repo 
error", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{ + deleteByGroupErr: errors.New("delete failed"), + } + svc := &adminServiceImpl{userGroupRateRepo: repo} + + err := svc.ClearGroupRateMultipliers(context.Background(), 42) + require.Error(t, err) + require.Contains(t, err.Error(), "delete failed") + }) +} + +func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) { + t.Run("syncs entries to repo", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{} + svc := &adminServiceImpl{userGroupRateRepo: repo} + + entries := []GroupRateMultiplierInput{ + {UserID: 1, RateMultiplier: 1.5}, + {UserID: 2, RateMultiplier: 0.8}, + } + err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, entries) + require.NoError(t, err) + require.Equal(t, int64(10), repo.syncedGroupID) + require.Equal(t, entries, repo.syncedEntries) + }) + + t.Run("returns nil when repo is nil", func(t *testing.T) { + svc := &adminServiceImpl{userGroupRateRepo: nil} + + err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, nil) + require.NoError(t, err) + }) + + t.Run("propagates repo error", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{ + syncGroupErr: errors.New("sync failed"), + } + svc := &adminServiceImpl{userGroupRateRepo: repo} + + err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, []GroupRateMultiplierInput{ + {UserID: 1, RateMultiplier: 1.0}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "sync failed") + }) +} diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go index 8b50530a..37f348df 100644 --- a/backend/internal/service/admin_service_list_users_test.go +++ b/backend/internal/service/admin_service_list_users_test.go @@ -68,7 +68,15 @@ func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context panic("unexpected SyncUserGroupRates call") } -func (s 
*userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error { +func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, _ int64) ([]UserGroupRateEntry, error) { + panic("unexpected GetByGroupID call") +} + +func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.Context, _ int64, _ []GroupRateMultiplierInput) error { + panic("unexpected SyncGroupRateMultipliers call") +} + +func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error { panic("unexpected DeleteByGroupID call") } diff --git a/backend/internal/service/admin_service_overages_test.go b/backend/internal/service/admin_service_overages_test.go new file mode 100644 index 00000000..779b08b9 --- /dev/null +++ b/backend/internal/service/admin_service_overages_test.go @@ -0,0 +1,123 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type updateAccountOveragesRepoStub struct { + mockAccountRepoForGemini + account *Account + updateCalls int +} + +func (r *updateAccountOveragesRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) { + return r.account, nil +} + +func (r *updateAccountOveragesRepoStub) Update(ctx context.Context, account *Account) error { + r.updateCalls++ + r.account = account + return nil +} + +func TestUpdateAccount_DisableOveragesClearsAICreditsKey(t *testing.T) { + accountID := int64(101) + repo := &updateAccountOveragesRepoStub{ + account: &Account{ + ID: accountID, + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Extra: map[string]any{ + "allow_overages": true, + "mixed_scheduling": true, + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limited_at": "2026-03-15T00:00:00Z", + "rate_limit_reset_at": "2099-03-15T00:00:00Z", + }, + creditsExhaustedKey: map[string]any{ + "rate_limited_at": "2026-03-15T00:00:00Z", + 
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339), + }, + }, + }, + }, + } + + svc := &adminServiceImpl{accountRepo: repo} + updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{ + Extra: map[string]any{ + "mixed_scheduling": true, + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limited_at": "2026-03-15T00:00:00Z", + "rate_limit_reset_at": "2099-03-15T00:00:00Z", + }, + creditsExhaustedKey: map[string]any{ + "rate_limited_at": "2026-03-15T00:00:00Z", + "rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339), + }, + }, + }, + }) + + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, 1, repo.updateCalls) + require.False(t, updated.IsOveragesEnabled()) + + // 关闭 overages 后,AICredits key 应被清除 + rawLimits, ok := repo.account.Extra[modelRateLimitsKey].(map[string]any) + if ok { + _, exists := rawLimits[creditsExhaustedKey] + require.False(t, exists, "关闭 overages 时应清除 AICredits 限流 key") + } + // 普通模型限流应保留 + require.True(t, ok) + _, exists := rawLimits["claude-sonnet-4-5"] + require.True(t, exists, "普通模型限流应保留") +} + +func TestUpdateAccount_EnableOveragesClearsModelRateLimitsBeforePersist(t *testing.T) { + accountID := int64(102) + repo := &updateAccountOveragesRepoStub{ + account: &Account{ + ID: accountID, + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Extra: map[string]any{ + "mixed_scheduling": true, + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limited_at": "2026-03-15T00:00:00Z", + "rate_limit_reset_at": "2099-03-15T00:00:00Z", + }, + }, + }, + }, + } + + svc := &adminServiceImpl{accountRepo: repo} + updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{ + Extra: map[string]any{ + "mixed_scheduling": true, + "allow_overages": true, + }, + }) + + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, 
1, repo.updateCalls) + require.True(t, updated.IsOveragesEnabled()) + + _, exists := repo.account.Extra[modelRateLimitsKey] + require.False(t, exists, "开启 overages 时应在持久化前清掉旧模型限流") +} diff --git a/backend/internal/service/announcement.go b/backend/internal/service/announcement.go index 2ba5af5d..25c66eb4 100644 --- a/backend/internal/service/announcement.go +++ b/backend/internal/service/announcement.go @@ -14,6 +14,11 @@ const ( AnnouncementStatusArchived = domain.AnnouncementStatusArchived ) +const ( + AnnouncementNotifyModeSilent = domain.AnnouncementNotifyModeSilent + AnnouncementNotifyModePopup = domain.AnnouncementNotifyModePopup +) + const ( AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance diff --git a/backend/internal/service/announcement_service.go b/backend/internal/service/announcement_service.go index c2588e6c..c0a0681a 100644 --- a/backend/internal/service/announcement_service.go +++ b/backend/internal/service/announcement_service.go @@ -33,23 +33,25 @@ func NewAnnouncementService( } type CreateAnnouncementInput struct { - Title string - Content string - Status string - Targeting AnnouncementTargeting - StartsAt *time.Time - EndsAt *time.Time - ActorID *int64 // 管理员用户ID + Title string + Content string + Status string + NotifyMode string + Targeting AnnouncementTargeting + StartsAt *time.Time + EndsAt *time.Time + ActorID *int64 // 管理员用户ID } type UpdateAnnouncementInput struct { - Title *string - Content *string - Status *string - Targeting *AnnouncementTargeting - StartsAt **time.Time - EndsAt **time.Time - ActorID *int64 // 管理员用户ID + Title *string + Content *string + Status *string + NotifyMode *string + Targeting *AnnouncementTargeting + StartsAt **time.Time + EndsAt **time.Time + ActorID *int64 // 管理员用户ID } type UserAnnouncement struct { @@ -93,6 +95,14 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem 
return nil, err } + notifyMode := strings.TrimSpace(input.NotifyMode) + if notifyMode == "" { + notifyMode = AnnouncementNotifyModeSilent + } + if !isValidAnnouncementNotifyMode(notifyMode) { + return nil, fmt.Errorf("create announcement: invalid notify_mode") + } + if input.StartsAt != nil && input.EndsAt != nil { if !input.StartsAt.Before(*input.EndsAt) { return nil, fmt.Errorf("create announcement: starts_at must be before ends_at") @@ -100,12 +110,13 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem } a := &Announcement{ - Title: title, - Content: content, - Status: status, - Targeting: targeting, - StartsAt: input.StartsAt, - EndsAt: input.EndsAt, + Title: title, + Content: content, + Status: status, + NotifyMode: notifyMode, + Targeting: targeting, + StartsAt: input.StartsAt, + EndsAt: input.EndsAt, } if input.ActorID != nil && *input.ActorID > 0 { a.CreatedBy = input.ActorID @@ -150,6 +161,14 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat a.Status = status } + if input.NotifyMode != nil { + notifyMode := strings.TrimSpace(*input.NotifyMode) + if !isValidAnnouncementNotifyMode(notifyMode) { + return nil, fmt.Errorf("update announcement: invalid notify_mode") + } + a.NotifyMode = notifyMode + } + if input.Targeting != nil { targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate() if err != nil { @@ -376,3 +395,12 @@ func isValidAnnouncementStatus(status string) bool { return false } } + +func isValidAnnouncementNotifyMode(mode string) bool { + switch mode { + case AnnouncementNotifyModeSilent, AnnouncementNotifyModePopup: + return true + default: + return false + } +} diff --git a/backend/internal/service/antigravity_credits_overages.go b/backend/internal/service/antigravity_credits_overages.go new file mode 100644 index 00000000..1521dfcd --- /dev/null +++ b/backend/internal/service/antigravity_credits_overages.go @@ -0,0 +1,234 @@ +package service + +import 
( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +const ( + // creditsExhaustedKey 是 model_rate_limits 中标记积分耗尽的特殊 key。 + // 与普通模型限流完全同构:通过 SetModelRateLimit / isRateLimitActiveForKey 读写。 + creditsExhaustedKey = "AICredits" + creditsExhaustedDuration = 5 * time.Hour +) + +type antigravity429Category string + +const ( + antigravity429Unknown antigravity429Category = "unknown" + antigravity429RateLimited antigravity429Category = "rate_limited" + antigravity429QuotaExhausted antigravity429Category = "quota_exhausted" +) + +var ( + antigravityQuotaExhaustedKeywords = []string{ + "quota_exhausted", + "quota exhausted", + } + + creditsExhaustedKeywords = []string{ + "google_one_ai", + "insufficient credit", + "insufficient credits", + "not enough credit", + "not enough credits", + "credit exhausted", + "credits exhausted", + "credit balance", + "minimumcreditamountforusage", + "minimum credit amount for usage", + "minimum credit", + } +) + +// isCreditsExhausted 检查账号的 AICredits 限流 key 是否生效(积分是否耗尽)。 +func (a *Account) isCreditsExhausted() bool { + if a == nil { + return false + } + return a.isRateLimitActiveForKey(creditsExhaustedKey) +} + +// setCreditsExhausted 标记账号积分耗尽:写入 model_rate_limits["AICredits"] + 更新缓存。 +func (s *AntigravityGatewayService) setCreditsExhausted(ctx context.Context, account *Account) { + if account == nil || account.ID == 0 { + return + } + resetAt := time.Now().Add(creditsExhaustedDuration) + if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, creditsExhaustedKey, resetAt); err != nil { + logger.LegacyPrintf("service.antigravity_gateway", "set credits exhausted failed: account=%d err=%v", account.ID, err) + return + } + s.updateAccountModelRateLimitInCache(ctx, account, creditsExhaustedKey, resetAt) + logger.LegacyPrintf("service.antigravity_gateway", "credits_exhausted_marked account=%d reset_at=%s", + 
account.ID, resetAt.UTC().Format(time.RFC3339)) +} + +// clearCreditsExhausted 清除账号的 AICredits 限流 key。 +func (s *AntigravityGatewayService) clearCreditsExhausted(ctx context.Context, account *Account) { + if account == nil || account.ID == 0 || account.Extra == nil { + return + } + rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any) + if !ok { + return + } + if _, exists := rawLimits[creditsExhaustedKey]; !exists { + return + } + delete(rawLimits, creditsExhaustedKey) + account.Extra[modelRateLimitsKey] = rawLimits + if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{ + modelRateLimitsKey: rawLimits, + }); err != nil { + logger.LegacyPrintf("service.antigravity_gateway", "clear credits exhausted failed: account=%d err=%v", account.ID, err) + } +} + +// classifyAntigravity429 将 Antigravity 的 429 响应归类为配额耗尽、限流或未知。 +func classifyAntigravity429(body []byte) antigravity429Category { + if len(body) == 0 { + return antigravity429Unknown + } + lowerBody := strings.ToLower(string(body)) + for _, keyword := range antigravityQuotaExhaustedKeywords { + if strings.Contains(lowerBody, keyword) { + return antigravity429QuotaExhausted + } + } + if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted { + return antigravity429RateLimited + } + return antigravity429Unknown +} + +// injectEnabledCreditTypes 在已序列化的 v1internal JSON body 中注入 AI Credits 类型。 +func injectEnabledCreditTypes(body []byte) []byte { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return nil + } + payload["enabledCreditTypes"] = []string{"GOOGLE_ONE_AI"} + result, err := json.Marshal(payload) + if err != nil { + return nil + } + return result +} + +// resolveCreditsOveragesModelKey 解析当前请求对应的 overages 状态模型 key。 +func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstreamModelName, requestedModel string) string { + modelKey := strings.TrimSpace(upstreamModelName) + if modelKey != "" 
{ + return modelKey + } + if account == nil { + return "" + } + modelKey = resolveFinalAntigravityModelKey(ctx, account, requestedModel) + if strings.TrimSpace(modelKey) != "" { + return modelKey + } + return resolveAntigravityModelKey(requestedModel) +} + +// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。 +func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool { + if reqErr != nil || resp == nil { + return false + } + if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout { + return false + } + if isURLLevelRateLimit(respBody) { + return false + } + if info := parseAntigravitySmartRetryInfo(respBody); info != nil { + return false + } + bodyLower := strings.ToLower(string(respBody)) + for _, keyword := range creditsExhaustedKeywords { + if strings.Contains(bodyLower, keyword) { + return true + } + } + return false +} + +type creditsOveragesRetryResult struct { + handled bool + resp *http.Response +} + +// attemptCreditsOveragesRetry 在确认免费配额耗尽后,尝试注入 AI Credits 继续请求。 +func (s *AntigravityGatewayService) attemptCreditsOveragesRetry( + p antigravityRetryLoopParams, + baseURL string, + modelName string, + waitDuration time.Duration, + originalStatusCode int, + respBody []byte, +) *creditsOveragesRetryResult { + creditsBody := injectEnabledCreditTypes(p.body) + if creditsBody == nil { + return &creditsOveragesRetryResult{handled: false} + } + modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)", + p.prefix, modelKey, p.account.ID) + + creditsReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, creditsBody) + if err != nil { + logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d build_request_err=%v", + p.prefix, modelKey, p.account.ID, err) + return 
&creditsOveragesRetryResult{handled: true} + } + + creditsResp, err := p.httpUpstream.Do(creditsReq, p.proxyURL, p.account.ID, p.account.Concurrency) + if err == nil && creditsResp != nil && creditsResp.StatusCode < 400 { + s.clearCreditsExhausted(p.ctx, p.account) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d credit_overages_success model=%s account=%d", + p.prefix, creditsResp.StatusCode, modelKey, p.account.ID) + return &creditsOveragesRetryResult{handled: true, resp: creditsResp} + } + + s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, creditsResp, err) + return &creditsOveragesRetryResult{handled: true} +} + +func (s *AntigravityGatewayService) handleCreditsRetryFailure( + ctx context.Context, + prefix string, + modelKey string, + account *Account, + creditsResp *http.Response, + reqErr error, +) { + var creditsRespBody []byte + creditsStatusCode := 0 + if creditsResp != nil { + creditsStatusCode = creditsResp.StatusCode + if creditsResp.Body != nil { + creditsRespBody, _ = io.ReadAll(io.LimitReader(creditsResp.Body, 64<<10)) + _ = creditsResp.Body.Close() + } + } + + if shouldMarkCreditsExhausted(creditsResp, creditsRespBody, reqErr) && account != nil { + s.setCreditsExhausted(ctx, account) + logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=true status=%d body=%s", + prefix, modelKey, account.ID, creditsStatusCode, truncateForLog(creditsRespBody, 200)) + return + } + if account != nil { + logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=false status=%d err=%v body=%s", + prefix, modelKey, account.ID, creditsStatusCode, reqErr, truncateForLog(creditsRespBody, 200)) + } +} diff --git a/backend/internal/service/antigravity_credits_overages_test.go b/backend/internal/service/antigravity_credits_overages_test.go new file mode 100644 index 00000000..bc679494 --- /dev/null +++ 
b/backend/internal/service/antigravity_credits_overages_test.go @@ -0,0 +1,538 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/stretchr/testify/require" +) + +func TestClassifyAntigravity429(t *testing.T) { + t.Run("明确配额耗尽", func(t *testing.T) { + body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`) + require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body)) + }) + + t.Run("结构化限流", func(t *testing.T) { + body := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }`) + require.Equal(t, antigravity429RateLimited, classifyAntigravity429(body)) + }) + + t.Run("未知429", func(t *testing.T) { + body := []byte(`{"error":{"message":"too many requests"}}`) + require.Equal(t, antigravity429Unknown, classifyAntigravity429(body)) + }) +} + +func TestIsCreditsExhausted_UsesAICreditsKey(t *testing.T) { + t.Run("无 AICredits key 则积分可用", func(t *testing.T) { + account := &Account{ + ID: 1, + Platform: PlatformAntigravity, + Extra: map[string]any{ + "allow_overages": true, + }, + } + require.False(t, account.isCreditsExhausted()) + }) + + t.Run("AICredits key 生效则积分耗尽", func(t *testing.T) { + account := &Account{ + ID: 2, + Platform: PlatformAntigravity, + Extra: map[string]any{ + "allow_overages": true, + modelRateLimitsKey: map[string]any{ + creditsExhaustedKey: map[string]any{ + "rate_limited_at": time.Now().UTC().Format(time.RFC3339), + "rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339), + }, + }, + }, + } + require.True(t, account.isCreditsExhausted()) + }) + + t.Run("AICredits key 过期则积分可用", func(t *testing.T) { 
+ account := &Account{ + ID: 3, + Platform: PlatformAntigravity, + Extra: map[string]any{ + "allow_overages": true, + modelRateLimitsKey: map[string]any{ + creditsExhaustedKey: map[string]any{ + "rate_limited_at": time.Now().Add(-6 * time.Hour).UTC().Format(time.RFC3339), + "rate_limit_reset_at": time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339), + }, + }, + }, + } + require.False(t, account.isCreditsExhausted()) + }) +} + +func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t *testing.T) { + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"ok":true}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 101, + Name: "acc-101", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Extra: map[string]any{ + "allow_overages": true, + }, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-opus-4-6": "claude-sonnet-4-5", + }, + }, + } + + respBody := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"model":"claude-opus-4-6","request":{}}`), + httpUpstream: upstream, + accountRepo: repo, + requestedModel: "claude-opus-4-6", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + svc := &AntigravityGatewayService{} + result := 
svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"}) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp) + require.Nil(t, result.switchError) + require.Len(t, upstream.requestBodies, 1) + require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes") + require.Empty(t, repo.modelRateLimitCalls, "overages 成功后不应写入普通 model_rate_limits") +} + +func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) { + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"ok":true}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 102, + Name: "acc-102", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Extra: map[string]any{ + "allow_overages": true, + }, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`), + httpUpstream: upstream, + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil 
+ }, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"}) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp) + require.Len(t, upstream.requestBodies, 1) + require.NotContains(t, string(upstream.requestBodies[0]), "enabledCreditTypes") + require.Empty(t, repo.extraUpdateCalls) + require.Empty(t, repo.modelRateLimitCalls) +} + +func TestAntigravityRetryLoop_ModelRateLimited_InjectsCredits(t *testing.T) { + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) + oldAvailability := antigravity.DefaultURLAvailability + defer func() { + antigravity.BaseURLs = oldBaseURLs + antigravity.DefaultURLAvailability = oldAvailability + }() + + antigravity.BaseURLs = []string{"https://ag-1.test"} + antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute) + + upstream := &queuedHTTPUpstreamStub{ + responses: []*http.Response{ + { + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"ok":true}`)), + }, + }, + errors: []error{nil}, + } + // 模型已限流 + overages 启用 + 无 AICredits key → 应直接注入积分 + account := &Account{ + ID: 103, + Name: "acc-103", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "allow_overages": true, + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limited_at": time.Now().UTC().Format(time.RFC3339), + "rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339), + }, + }, + }, + } + + svc := &AntigravityGatewayService{} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`), + httpUpstream: 
upstream, + requestedModel: "claude-sonnet-4-5", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, upstream.requestBodies, 1) + require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes") +} + +func TestAntigravityRetryLoop_CreditsExhausted_DoesNotInject(t *testing.T) { + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) + oldAvailability := antigravity.DefaultURLAvailability + defer func() { + antigravity.BaseURLs = oldBaseURLs + antigravity.DefaultURLAvailability = oldAvailability + }() + + antigravity.BaseURLs = []string{"https://ag-1.test"} + antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute) + + // 模型限流 + overages 启用 + AICredits key 生效 → 不应注入积分,应切号 + account := &Account{ + ID: 104, + Name: "acc-104", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "allow_overages": true, + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limited_at": time.Now().UTC().Format(time.RFC3339), + "rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339), + }, + creditsExhaustedKey: map[string]any{ + "rate_limited_at": time.Now().UTC().Format(time.RFC3339), + "rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339), + }, + }, + }, + } + + svc := &AntigravityGatewayService{} + _, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`), + requestedModel: "claude-sonnet-4-5", + handleError: func(ctx 
context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + // 模型限流 + 积分耗尽 → 应触发切号错误 + require.Error(t, err) + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) +} + +func TestAntigravityRetryLoop_CreditErrorMarksExhausted(t *testing.T) { + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) + oldAvailability := antigravity.DefaultURLAvailability + defer func() { + antigravity.BaseURLs = oldBaseURLs + antigravity.DefaultURLAvailability = oldAvailability + }() + + antigravity.BaseURLs = []string{"https://ag-1.test"} + antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute) + + repo := &stubAntigravityAccountRepo{} + upstream := &queuedHTTPUpstreamStub{ + responses: []*http.Response{ + { + StatusCode: http.StatusForbidden, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`)), + }, + }, + errors: []error{nil}, + } + // 模型限流 + overages 启用 + 积分可用 → 注入积分但上游返回积分不足 + account := &Account{ + ID: 105, + Name: "acc-105", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "allow_overages": true, + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limited_at": time.Now().UTC().Format(time.RFC3339), + "rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339), + }, + }, + }, + } + + svc := &AntigravityGatewayService{accountRepo: repo} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`), + httpUpstream: upstream, + accountRepo: repo, + 
requestedModel: "claude-sonnet-4-5", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.NoError(t, err) + require.NotNil(t, result) + // 验证 AICredits key 已通过 SetModelRateLimit 写入数据库 + require.Len(t, repo.modelRateLimitCalls, 1, "应通过 SetModelRateLimit 写入 AICredits key") + require.Equal(t, creditsExhaustedKey, repo.modelRateLimitCalls[0].modelKey) +} + +func TestShouldMarkCreditsExhausted(t *testing.T) { + t.Run("reqErr 不为 nil 时不标记", func(t *testing.T) { + resp := &http.Response{StatusCode: http.StatusForbidden} + require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), io.ErrUnexpectedEOF)) + }) + + t.Run("resp 为 nil 时不标记", func(t *testing.T) { + require.False(t, shouldMarkCreditsExhausted(nil, []byte(`{"error":"Insufficient credits"}`), nil)) + }) + + t.Run("5xx 响应不标记", func(t *testing.T) { + resp := &http.Response{StatusCode: http.StatusInternalServerError} + require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil)) + }) + + t.Run("408 RequestTimeout 不标记", func(t *testing.T) { + resp := &http.Response{StatusCode: http.StatusRequestTimeout} + require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil)) + }) + + t.Run("URL 级限流不标记", func(t *testing.T) { + resp := &http.Response{StatusCode: http.StatusTooManyRequests} + body := []byte(`{"error":{"message":"Resource has been exhausted"}}`) + require.False(t, shouldMarkCreditsExhausted(resp, body, nil)) + }) + + t.Run("结构化限流不标记", func(t *testing.T) { + resp := &http.Response{StatusCode: http.StatusTooManyRequests} + body := 
[]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`) + require.False(t, shouldMarkCreditsExhausted(resp, body, nil)) + }) + + t.Run("含 credits 关键词时标记", func(t *testing.T) { + resp := &http.Response{StatusCode: http.StatusForbidden} + for _, keyword := range []string{ + "Insufficient GOOGLE_ONE_AI credits", + "insufficient credit balance", + "not enough credits for this request", + "Credits exhausted", + "minimumCreditAmountForUsage requirement not met", + } { + body := []byte(`{"error":{"message":"` + keyword + `"}}`) + require.True(t, shouldMarkCreditsExhausted(resp, body, nil), "should mark for keyword: %s", keyword) + } + }) + + t.Run("无 credits 关键词时不标记", func(t *testing.T) { + resp := &http.Response{StatusCode: http.StatusForbidden} + body := []byte(`{"error":{"message":"permission denied"}}`) + require.False(t, shouldMarkCreditsExhausted(resp, body, nil)) + }) +} + +func TestInjectEnabledCreditTypes(t *testing.T) { + t.Run("正常 JSON 注入成功", func(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-5","request":{}}`) + result := injectEnabledCreditTypes(body) + require.NotNil(t, result) + require.Contains(t, string(result), `"enabledCreditTypes"`) + require.Contains(t, string(result), `GOOGLE_ONE_AI`) + }) + + t.Run("非法 JSON 返回 nil", func(t *testing.T) { + require.Nil(t, injectEnabledCreditTypes([]byte(`not json`))) + }) + + t.Run("空 body 返回 nil", func(t *testing.T) { + require.Nil(t, injectEnabledCreditTypes([]byte{})) + }) + + t.Run("已有 enabledCreditTypes 会被覆盖", func(t *testing.T) { + body := []byte(`{"enabledCreditTypes":["OLD"],"model":"test"}`) + result := injectEnabledCreditTypes(body) + require.NotNil(t, result) + require.Contains(t, string(result), `GOOGLE_ONE_AI`) + require.NotContains(t, string(result), `OLD`) + }) +} + +func TestClearCreditsExhausted(t *testing.T) { + 
t.Run("account 为 nil 不操作", func(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + svc.clearCreditsExhausted(context.Background(), nil) + require.Empty(t, repo.extraUpdateCalls) + }) + + t.Run("Extra 为 nil 不操作", func(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + svc.clearCreditsExhausted(context.Background(), &Account{ID: 1}) + require.Empty(t, repo.extraUpdateCalls) + }) + + t.Run("无 modelRateLimitsKey 不操作", func(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + svc.clearCreditsExhausted(context.Background(), &Account{ + ID: 1, + Extra: map[string]any{"some_key": "value"}, + }) + require.Empty(t, repo.extraUpdateCalls) + }) + + t.Run("无 AICredits key 不操作", func(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + svc.clearCreditsExhausted(context.Background(), &Account{ + ID: 1, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limited_at": "2026-03-15T00:00:00Z", + "rate_limit_reset_at": "2099-03-15T00:00:00Z", + }, + }, + }, + }) + require.Empty(t, repo.extraUpdateCalls) + }) + + t.Run("有 AICredits key 时删除并调用 UpdateExtra", func(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ + ID: 1, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limited_at": "2026-03-15T00:00:00Z", + "rate_limit_reset_at": "2099-03-15T00:00:00Z", + }, + creditsExhaustedKey: map[string]any{ + "rate_limited_at": "2026-03-15T00:00:00Z", + "rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339), + }, + }, + }, + } + svc.clearCreditsExhausted(context.Background(), account) + require.Len(t, repo.extraUpdateCalls, 1) + // 
AICredits key 应被删除 + rawLimits := account.Extra[modelRateLimitsKey].(map[string]any) + _, exists := rawLimits[creditsExhaustedKey] + require.False(t, exists, "AICredits key 应被删除") + // 普通模型限流应保留 + _, exists = rawLimits["claude-sonnet-4-5"] + require.True(t, exists, "普通模型限流应保留") + }) +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 96ff3354..cafc2a79 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -188,9 +188,29 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam return &smartRetryResult{action: smartRetryActionContinueURL} } + category := antigravity429Unknown + if resp.StatusCode == http.StatusTooManyRequests { + category = classifyAntigravity429(respBody) + } + // 判断是否触发智能重试 shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody) + // AI Credits 超量请求: + // 仅在上游明确返回免费配额耗尽时才允许切换到 credits。 + if resp.StatusCode == http.StatusTooManyRequests && + category == antigravity429QuotaExhausted && + p.account.IsOveragesEnabled() && + !p.account.isCreditsExhausted() { + result := s.attemptCreditsOveragesRetry(p, baseURL, modelName, waitDuration, resp.StatusCode, respBody) + if result.handled && result.resp != nil { + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + resp: result.resp, + } + } + } + // 情况1: retryDelay >= 阈值,限流模型并切换账号 if shouldRateLimitModel { // 单账号 503 退避重试模式:不设限流、不切换账号,改为原地等待+重试 @@ -532,14 +552,31 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( // antigravityRetryLoop 执行带 URL fallback 的重试循环 func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { + // 预检查:模型限流 + overages 启用 + 积分未耗尽 → 直接注入 AI Credits + overagesInjected := false + if p.requestedModel != "" && 
p.account.Platform == PlatformAntigravity && + p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() && + p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) { + if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil { + p.body = creditsBody + overagesInjected = true + logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)", + p.prefix, p.requestedModel, p.account.ID) + } + } + // 预检查:如果账号已限流,直接返回切换信号 if p.requestedModel != "" { if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 { - // 单账号 503 退避重试模式:跳过限流预检查,直接发请求。 - // 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。 - // 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace - // 会在 Service 层原地等待+重试,不需要在预检查这里等。 - if isSingleAccountRetry(p.ctx) { + // 已注入积分的请求不再受普通模型限流预检查阻断。 + if overagesInjected { + logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: credits_injected_ignore_rate_limit remaining=%v model=%s account=%d", + p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) + } else if isSingleAccountRetry(p.ctx) { + // 单账号 503 退避重试模式:跳过限流预检查,直接发请求。 + // 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。 + // 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace + // 会在 Service 层原地等待+重试,不需要在预检查这里等。 logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)", p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) } else { @@ -631,6 +668,15 @@ urlFallbackLoop: respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() + if overagesInjected && shouldMarkCreditsExhausted(resp, respBody, nil) { + modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, "", p.requestedModel) + s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, 
&http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + }, nil) + } + // ★ 统一入口:自定义错误码 + 临时不可调度 if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { if policyErr != nil { @@ -1384,7 +1430,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 优先检测 thinking block 的 signature 相关错误(400)并重试一次: // Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验, // 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。 - if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) { + if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) logBody, maxBytes := s.getLogConfig() @@ -1517,6 +1563,80 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } } + // Budget 整流:检测 budget_tokens 约束错误并自动修正重试 + if resp.StatusCode == http.StatusBadRequest && respBody != nil && !isSignatureRelatedError(respBody) { + errMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "budget_constraint_error", + Message: errMsg, + Detail: s.getUpstreamErrorDetail(respBody), + }) + + // 修正 claudeReq 的 thinking 参数(adaptive 模式不修正) + if claudeReq.Thinking == nil || claudeReq.Thinking.Type != "adaptive" { + retryClaudeReq := claudeReq + retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...) 
+ // 创建新的 ThinkingConfig 避免修改原始 claudeReq.Thinking 指针 + retryClaudeReq.Thinking = &antigravity.ThinkingConfig{ + Type: "enabled", + BudgetTokens: BudgetRectifyBudgetTokens, + } + if retryClaudeReq.MaxTokens < BudgetRectifyMinMaxTokens { + retryClaudeReq.MaxTokens = BudgetRectifyMaxTokens + } + + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens) + + retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, transformOpts) + if txErr == nil { + retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: action, + body: retryGeminiBody, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + accountRepo: s.accountRepo, + handleError: s.handleUpstreamError, + requestedModel: originalModel, + isStickySession: isStickySession, + groupID: 0, + sessionHash: "", + }) + if retryErr == nil { + retryResp := retryResult.resp + if retryResp.StatusCode < 400 { + _ = resp.Body.Close() + resp = retryResp + respBody = nil + } else { + retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + _ = retryResp.Body.Close() + respBody = retryBody + resp = &http.Response{ + StatusCode: retryResp.StatusCode, + Header: retryResp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryBody)), + } + } + } else { + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: budget rectifier retry failed: %v", account.ID, retryErr) + } + } + } + } + } + // 处理错误响应(重试后仍失败或不触发重试) if resp.StatusCode >= 400 { // 检测 prompt too long 错误,返回特殊错误类型供上层 fallback @@ -2090,6 +2210,112 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co } } + // Gemini 原生请求中的 
thoughtSignature 可能来自旧上下文/旧账号,触发上游严格校验后返回 + // "Corrupted thought signature."。检测到此类 400 时,将 thoughtSignature 清理为 dummy 值后重试一次。 + signatureCheckBody := respBody + if unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && len(unwrapped) > 0 { + signatureCheckBody = unwrapped + } + if resp.StatusCode == http.StatusBadRequest && + s.settingService != nil && + s.settingService.IsSignatureRectifierEnabled(ctx) && + isSignatureRelatedError(signatureCheckBody) && + bytes.Contains(injectedBody, []byte(`"thoughtSignature"`)) { + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(signatureCheckBody))) + upstreamDetail := s.getUpstreamErrorDetail(signatureCheckBody) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "signature_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID) + + cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody) + retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody) + if wrapErr == nil { + retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: upstreamAction, + body: retryWrappedBody, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + accountRepo: s.accountRepo, + handleError: s.handleUpstreamError, + requestedModel: originalModel, + isStickySession: isStickySession, + groupID: 0, + sessionHash: "", + }) + if retryErr == nil { + retryResp := retryResult.resp + if retryResp.StatusCode < 400 { + resp = 
retryResp + } else { + retryRespBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + _ = retryResp.Body.Close() + retryOpsBody := retryRespBody + if retryUnwrapped, unwrapErr := s.unwrapV1InternalResponse(retryRespBody); unwrapErr == nil && len(retryUnwrapped) > 0 { + retryOpsBody = retryUnwrapped + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: retryResp.StatusCode, + UpstreamRequestID: retryResp.Header.Get("x-request-id"), + Kind: "signature_retry", + Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(retryOpsBody))), + Detail: s.getUpstreamErrorDetail(retryOpsBody), + }) + respBody = retryRespBody + resp = &http.Response{ + StatusCode: retryResp.StatusCode, + Header: retryResp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryRespBody)), + } + contentType = resp.Header.Get("Content-Type") + } + } else { + if switchErr, ok := IsAntigravityAccountSwitchError(retryErr); ok { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: http.StatusServiceUnavailable, + Kind: "failover", + Message: sanitizeUpstreamErrorMessage(retryErr.Error()), + }) + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusServiceUnavailable, + ForceCacheBilling: switchErr.IsStickySession, + } + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "signature_retry_request_error", + Message: sanitizeUpstreamErrorMessage(retryErr.Error()), + }) + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry request failed: %v", account.ID, retryErr) + } + } else { + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry wrap 
failed: %v", account.ID, wrapErr) + } + } + // fallback 成功:继续按正常响应处理 if resp.StatusCode < 400 { goto handleSuccess @@ -3696,6 +3922,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context finalEvents, agUsage := processor.Finish() if len(finalEvents) > 0 { cw.Write(finalEvents) + } else if !processor.MessageStartSent() && !cw.Disconnected() { + // 整个流未收到任何可解析的上游数据(全部 SSE 行均无法被 JSON 解析), + // 触发 failover 在同账号重试,避免向客户端发出缺少 message_start 的残缺流 + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Claude-Stream] empty stream response (no valid events parsed), triggering failover") + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: []byte(`{"error":"empty stream response from upstream"}`), + RetryableOnSameAccount: true, + } } return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil } diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 84b65adc..6e0a7305 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -134,6 +134,47 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, return s.resp, s.err } +type queuedHTTPUpstreamStub struct { + responses []*http.Response + errors []error + requestBodies [][]byte + callCount int + onCall func(*http.Request, *queuedHTTPUpstreamStub) +} + +func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + if req != nil && req.Body != nil { + body, _ := io.ReadAll(req.Body) + s.requestBodies = append(s.requestBodies, body) + req.Body = io.NopCloser(bytes.NewReader(body)) + } else { + s.requestBodies = append(s.requestBodies, nil) + } + + idx := s.callCount + s.callCount++ + if s.onCall != nil { + s.onCall(req, s) + } + + var resp 
*http.Response + if idx < len(s.responses) { + resp = s.responses[idx] + } + var err error + if idx < len(s.errors) { + err = s.errors[idx] + } + if resp == nil && err == nil { + return nil, errors.New("unexpected upstream call") + } + return resp, err +} + +func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) { + return s.Do(req, proxyURL, accountID, concurrency) +} + type antigravitySettingRepoStub struct{} func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { @@ -556,6 +597,177 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing require.Equal(t, mappedModel, result.Model) } +func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hello"}}}, + {"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}}, + {"role": "model", "parts": []map[string]any{{"functionCall": map[string]any{"name": "toolA", "args": map[string]any{"x": 1}}, "thoughtSignature": "sig_bad_2"}}}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body)) + c.Request = req + + firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`) + secondRespBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") + + upstream := &queuedHTTPUpstreamStub{ + responses: 
[]*http.Response{ + { + StatusCode: http.StatusBadRequest, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"req-sig-1"}, + }, + Body: io.NopCloser(bytes.NewReader(firstRespBody)), + }, + { + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req-sig-2"}, + }, + Body: io.NopCloser(bytes.NewReader(secondRespBody)), + }, + }, + } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: upstream, + } + + const originalModel = "gemini-3.1-pro-preview" + const mappedModel = "gemini-3.1-pro-high" + account := &Account{ + ID: 7, + Name: "acc-gemini-signature", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + originalModel: mappedModel, + }, + }, + } + + result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, mappedModel, result.Model) + require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry") + + firstReq := string(upstream.requestBodies[0]) + secondReq := string(upstream.requestBodies[1]) + require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_1"`) + require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_2"`) + require.Contains(t, secondReq, `"thoughtSignature":"skip_thought_signature_validator"`) + require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_1"`) + require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_2"`) + + raw, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := 
raw.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.NotEmpty(t, events) + require.Equal(t, "signature_error", events[0].Kind) +} + +func TestAntigravityGatewayService_ForwardGemini_SignatureRetryPropagatesFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hello"}}}, + {"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body)) + c.Request = req + + firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`) + + const originalModel = "gemini-3.1-pro-preview" + const mappedModel = "gemini-3.1-pro-high" + account := &Account{ + ID: 8, + Name: "acc-gemini-signature-failover", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + originalModel: mappedModel, + }, + }, + } + + upstream := &queuedHTTPUpstreamStub{ + responses: []*http.Response{ + { + StatusCode: http.StatusBadRequest, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"req-sig-failover-1"}, + }, + Body: io.NopCloser(bytes.NewReader(firstRespBody)), + }, + }, + onCall: func(_ *http.Request, stub *queuedHTTPUpstreamStub) { + if stub.callCount != 1 { + return + } + futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339) + account.Extra = map[string]any{ + modelRateLimitsKey: map[string]any{ + mappedModel: map[string]any{ + "rate_limit_reset_at": futureResetAt, + }, + }, + } + }, + } + + 
svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: upstream, + } + + result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, true) + require.Nil(t, result) + + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr, "signature retry should propagate failover instead of falling back to the original 400") + require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode) + require.True(t, failoverErr.ForceCacheBilling) + require.Len(t, upstream.requestBodies, 1, "retry should stop at preflight failover and not issue a second upstream request") + + raw, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := raw.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 2) + require.Equal(t, "signature_error", events[0].Kind) + require.Equal(t, "failover", events[1].Kind) +} + // TestStreamUpstreamResponse_UsageAndFirstToken // 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间 func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { @@ -998,6 +1210,46 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) { require.True(t, result.clientDisconnect) } +// TestHandleClaudeStreamingResponse_EmptyStream +// 验证:上游只返回无法解析的 SSE 行时,触发 UpstreamFailoverError 而不是向客户端发出残缺流 +func TestHandleClaudeStreamingResponse_EmptyStream(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() 
{ _ = pw.Close() }() + // 所有行均为无法 JSON 解析的内容,ProcessLine 全部返回 nil + fmt.Fprintln(pw, "data: not-valid-json") + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, "data: also-invalid") + fmt.Fprintln(pw, "") + }() + + _, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + _ = pr.Close() + + // 应当返回 UpstreamFailoverError 而非 nil,以便上层触发 failover + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.True(t, failoverErr.RetryableOnSameAccount) + + // 客户端不应收到任何 SSE 事件(既无 message_start 也无 message_stop) + body := rec.Body.String() + require.NotContains(t, body, "event: message_start") + require.NotContains(t, body, "event: message_stop") + require.NotContains(t, body, "event: message_delta") +} + // TestHandleClaudeStreamingResponse_ContextCanceled // 验证:context 取消时不注入错误事件 func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) { diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index b67c7faf..5f6691be 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -112,7 +112,10 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig } } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } // 交换 token tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier) @@ -167,7 +170,10 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken time.Sleep(backoff) } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } tokenResp, err := client.RefreshToken(ctx, refreshToken) if err == nil { now := time.Now() @@ -209,7 
+215,10 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr } // 获取用户信息(email) - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken) if err != nil { fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err) @@ -309,7 +318,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac time.Sleep(backoff) } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return "", fmt.Errorf("create antigravity client failed: %w", err) + } loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken) if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go index 07eb563d..9e09c904 100644 --- a/backend/internal/service/antigravity_quota_fetcher.go +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -2,11 +2,29 @@ package service import ( "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "regexp" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) +const ( + forbiddenTypeValidation = "validation" + forbiddenTypeViolation = "violation" + forbiddenTypeForbidden = "forbidden" + + // 机器可读的错误码 + errorCodeForbidden = "forbidden" + errorCodeUnauthenticated = "unauthenticated" + errorCodeRateLimited = "rate_limited" + errorCodeNetworkError = "network_error" +) + // AntigravityQuotaFetcher 从 Antigravity API 获取额度 type AntigravityQuotaFetcher struct { proxyRepo ProxyRepository @@ -31,16 +49,40 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou accessToken := account.GetCredential("access_token") projectID := account.GetCredential("project_id") - client := 
antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } // 调用 API 获取配额 modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) if err != nil { + // 403 Forbidden: 不报错,返回 is_forbidden 标记 + var forbiddenErr *antigravity.ForbiddenError + if errors.As(err, &forbiddenErr) { + now := time.Now() + fbType := classifyForbiddenType(forbiddenErr.Body) + return &QuotaResult{ + UsageInfo: &UsageInfo{ + UpdatedAt: &now, + IsForbidden: true, + ForbiddenReason: forbiddenErr.Body, + ForbiddenType: fbType, + ValidationURL: extractValidationURL(forbiddenErr.Body), + NeedsVerify: fbType == forbiddenTypeValidation, + IsBanned: fbType == forbiddenTypeViolation, + ErrorCode: errorCodeForbidden, + }, + }, nil + } return nil, err } + // 调用 LoadCodeAssist 获取订阅等级和 AI Credits 余额(非关键路径,失败不影响主流程) + tierRaw, tierNormalized, loadResp := f.fetchSubscriptionTier(ctx, client, accessToken) + // 转换为 UsageInfo - usageInfo := f.buildUsageInfo(modelsResp) + usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized, loadResp) return &QuotaResult{ UsageInfo: usageInfo, @@ -48,15 +90,53 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou }, nil } -// buildUsageInfo 将 API 响应转换为 UsageInfo -func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo { - now := time.Now() - info := &UsageInfo{ - UpdatedAt: &now, - AntigravityQuota: make(map[string]*AntigravityModelQuota), +// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串。 +// 同时返回 LoadCodeAssistResponse,以便提取 AI Credits 余额。 +func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string, loadResp *antigravity.LoadCodeAssistResponse) { + loadResp, _, err := client.LoadCodeAssist(ctx, accessToken) + if err != nil { + slog.Warn("failed to fetch 
subscription tier", "error", err) + return "", "", nil + } + if loadResp == nil { + return "", "", nil } - // 遍历所有模型,填充 AntigravityQuota + raw = loadResp.GetTier() // 已有方法:paidTier > currentTier + normalized = normalizeTier(raw) + return raw, normalized, loadResp +} + +// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN +func normalizeTier(raw string) string { + if raw == "" { + return "" + } + lower := strings.ToLower(raw) + switch { + case strings.Contains(lower, "ultra"): + return "ULTRA" + case strings.Contains(lower, "pro"): + return "PRO" + case strings.Contains(lower, "free"): + return "FREE" + default: + return "UNKNOWN" + } +} + +// buildUsageInfo 将 API 响应转换为 UsageInfo。 +func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string, loadResp *antigravity.LoadCodeAssistResponse) *UsageInfo { + now := time.Now() + info := &UsageInfo{ + UpdatedAt: &now, + AntigravityQuota: make(map[string]*AntigravityModelQuota), + AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail), + SubscriptionTier: tierNormalized, + SubscriptionTierRaw: tierRaw, + } + + // 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails for modelName, modelInfo := range modelsResp.Models { if modelInfo.QuotaInfo == nil { continue @@ -69,6 +149,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv Utilization: utilization, ResetTime: modelInfo.QuotaInfo.ResetTime, } + + // 填充模型详细能力信息 + detail := &AntigravityModelDetail{ + DisplayName: modelInfo.DisplayName, + SupportsImages: modelInfo.SupportsImages, + SupportsThinking: modelInfo.SupportsThinking, + ThinkingBudget: modelInfo.ThinkingBudget, + Recommended: modelInfo.Recommended, + MaxTokens: modelInfo.MaxTokens, + MaxOutputTokens: modelInfo.MaxOutputTokens, + SupportedMimeTypes: modelInfo.SupportedMimeTypes, + } + info.AntigravityQuotaDetails[modelName] = detail + } + + // 废弃模型转发规则 + if len(modelsResp.DeprecatedModelIDs) > 0 { + 
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs)) + for oldID, deprecated := range modelsResp.DeprecatedModelIDs { + info.ModelForwardingRules[oldID] = deprecated.NewModelID + } } // 同时设置 FiveHour 用于兼容展示(取主要模型) @@ -90,6 +191,16 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv } } + if loadResp != nil { + for _, credit := range loadResp.GetAvailableCredits() { + info.AICredits = append(info.AICredits, AICredit{ + CreditType: credit.CreditType, + Amount: credit.GetAmount(), + MinimumBalance: credit.GetMinimumAmount(), + }) + } + } + return info } @@ -104,3 +215,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco } return proxy.URL() } + +// classifyForbiddenType 根据 403 响应体判断禁止类型 +func classifyForbiddenType(body string) string { + lower := strings.ToLower(body) + switch { + case strings.Contains(lower, "validation_required") || + strings.Contains(lower, "verify your account") || + strings.Contains(lower, "validation_url"): + return forbiddenTypeValidation + case strings.Contains(lower, "terms of service") || + strings.Contains(lower, "violation"): + return forbiddenTypeViolation + default: + return forbiddenTypeForbidden + } +} + +// urlPattern 用于从 403 响应体中提取 URL(降级方案) +var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`) + +// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接 +func extractValidationURL(body string) string { + // 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url + var parsed struct { + Error struct { + Details []struct { + Metadata map[string]string `json:"metadata"` + } `json:"details"` + } `json:"error"` + } + if json.Unmarshal([]byte(body), &parsed) == nil { + for _, detail := range parsed.Error.Details { + if u := detail.Metadata["validation_url"]; u != "" { + return u + } + if u := detail.Metadata["appeal_url"]; u != "" { + return u + } + } + } + + // 2. 
降级:正则匹配 URL + lower := strings.ToLower(body) + if !strings.Contains(lower, "validation") && + !strings.Contains(lower, "verify") && + !strings.Contains(lower, "appeal") { + return "" + } + // 先解码常见转义再匹配 + normalized := strings.ReplaceAll(body, `\u0026`, "&") + if m := urlPattern.FindString(normalized); m != "" { + return m + } + return "" +} diff --git a/backend/internal/service/antigravity_quota_fetcher_test.go b/backend/internal/service/antigravity_quota_fetcher_test.go new file mode 100644 index 00000000..e0f57051 --- /dev/null +++ b/backend/internal/service/antigravity_quota_fetcher_test.go @@ -0,0 +1,522 @@ +//go:build unit + +package service + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" +) + +// --------------------------------------------------------------------------- +// normalizeTier +// --------------------------------------------------------------------------- + +func TestNormalizeTier(t *testing.T) { + tests := []struct { + name string + raw string + expected string + }{ + {name: "empty string", raw: "", expected: ""}, + {name: "free-tier", raw: "free-tier", expected: "FREE"}, + {name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"}, + {name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"}, + {name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"}, + {name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"}, + {name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"}, + {name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"}, + {name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeTier(tt.raw) + require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw) + }) + } +} + +// --------------------------------------------------------------------------- +// 
buildUsageInfo +// --------------------------------------------------------------------------- + +func aqfBoolPtr(v bool) *bool { return &v } +func aqfIntPtr(v int) *int { return &v } + +func TestBuildUsageInfo_BasicModels(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.75, + ResetTime: "2026-03-08T12:00:00Z", + }, + DisplayName: "Claude Sonnet 4", + SupportsImages: aqfBoolPtr(true), + SupportsThinking: aqfBoolPtr(false), + ThinkingBudget: aqfIntPtr(0), + Recommended: aqfBoolPtr(true), + MaxTokens: aqfIntPtr(200000), + MaxOutputTokens: aqfIntPtr(16384), + SupportedMimeTypes: map[string]bool{ + "image/png": true, + "image/jpeg": true, + }, + }, + "gemini-2.5-pro": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.50, + ResetTime: "2026-03-08T15:00:00Z", + }, + DisplayName: "Gemini 2.5 Pro", + MaxTokens: aqfIntPtr(1000000), + MaxOutputTokens: aqfIntPtr(65536), + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", nil) + + // 基本字段 + require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set") + require.Equal(t, "PRO", info.SubscriptionTier) + require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw) + + // AntigravityQuota + require.Len(t, info.AntigravityQuota, 2) + + sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"] + require.NotNil(t, sonnetQuota) + require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25 + require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime) + + geminiQuota := info.AntigravityQuota["gemini-2.5-pro"] + require.NotNil(t, geminiQuota) + require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50 + require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime) + + // AntigravityQuotaDetails + require.Len(t, info.AntigravityQuotaDetails, 2) + + 
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"] + require.NotNil(t, sonnetDetail) + require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName) + require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages) + require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking) + require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget) + require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended) + require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens) + require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens) + require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes) + + geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"] + require.NotNil(t, geminiDetail) + require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName) + require.Nil(t, geminiDetail.SupportsImages) + require.Nil(t, geminiDetail.SupportsThinking) + require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens) + require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens) +} + +func TestBuildUsageInfo_DeprecatedModels(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 1.0, + }, + }, + }, + DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{ + "claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"}, + "claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"}, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "", nil) + + require.Len(t, info.ModelForwardingRules, 2) + require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"]) + require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"]) +} + +func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) { + fetcher := 
&AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "some-model": { + QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9}, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "", nil) + + require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models") +} + +func TestBuildUsageInfo_EmptyModels(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{}, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "", nil) + + require.NotNil(t, info) + require.NotNil(t, info.AntigravityQuota) + require.Empty(t, info.AntigravityQuota) + require.NotNil(t, info.AntigravityQuotaDetails) + require.Empty(t, info.AntigravityQuotaDetails) + require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists") +} + +func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "model-without-quota": { + DisplayName: "No Quota Model", + // QuotaInfo is nil + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "", nil) + + require.NotNil(t, info) + require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped") + require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too") +} + +func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"] + // When the first priority model exists, it should be used for FiveHour + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "gemini-2.5-pro": { + QuotaInfo: 
&antigravity.ModelQuotaInfo{ + RemainingFraction: 0.40, + ResetTime: "2026-03-08T18:00:00Z", + }, + }, + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.80, + ResetTime: "2026-03-08T12:00:00Z", + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "", nil) + + require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists") + // claude-sonnet-4-20250514 is first in priority list, so it should be used + expectedUtilization := (1.0 - 0.80) * 100 // 20 + require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01) + require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime") +} + +func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514 + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.60, + ResetTime: "2026-03-08T14:00:00Z", + }, + }, + "gemini-2.5-pro": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.30, + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "", nil) + + require.NotNil(t, info.FiveHour) + expectedUtilization := (1.0 - 0.60) * 100 // 40 + require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01) +} + +func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // Only gemini-2.5-pro exists (third in priority list) + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "gemini-2.5-pro": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.30, + }, + }, + "other-model": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.90, + }, + }, + }, + } + + info := 
fetcher.buildUsageInfo(modelsResp, "", "", nil) + + require.NotNil(t, info.FiveHour) + expectedUtilization := (1.0 - 0.30) * 100 // 70 + require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01) +} + +func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // None of the priority models exist + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "some-other-model": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.50, + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "", nil) + + require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists") +} + +func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.50, + ResetTime: "", // empty reset time + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "", nil) + + require.NotNil(t, info.FiveHour) + require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty") + require.Equal(t, 0, info.FiveHour.RemainingSeconds) +} + +func TestBuildUsageInfo_FullUtilization(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.0, // fully used + ResetTime: "2026-03-08T12:00:00Z", + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "", nil) + + quota := info.AntigravityQuota["claude-sonnet-4-20250514"] + require.NotNil(t, quota) + require.Equal(t, 100, quota.Utilization) +} + +func TestBuildUsageInfo_ZeroUtilization(t *testing.T) { + fetcher := 
&AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 1.0, // fully available + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "", nil) + quota := info.AntigravityQuota["claude-sonnet-4-20250514"] + require.NotNil(t, quota) + require.Equal(t, 0, quota.Utilization) +} + +func TestBuildUsageInfo_AICredits(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{}, + } + loadResp := &antigravity.LoadCodeAssistResponse{ + PaidTier: &antigravity.PaidTierInfo{ + ID: "g1-pro-tier", + AvailableCredits: []antigravity.AvailableCredit{ + { + CreditType: "GOOGLE_ONE_AI", + CreditAmount: "25", + MinimumCreditAmountForUsage: "5", + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", loadResp) + + require.Len(t, info.AICredits, 1) + require.Equal(t, "GOOGLE_ONE_AI", info.AICredits[0].CreditType) + require.Equal(t, 25.0, info.AICredits[0].Amount) + require.Equal(t, 5.0, info.AICredits[0].MinimumBalance) +} + +func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) { + // 模拟 FetchQuota 遇到 403 时的行为: + // FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true + forbiddenErr := &antigravity.ForbiddenError{ + StatusCode: 403, + Body: "Access denied", + } + + // 验证 ForbiddenError 满足 errors.As + var target *antigravity.ForbiddenError + require.True(t, errors.As(forbiddenErr, &target)) + require.Equal(t, 403, target.StatusCode) + require.Equal(t, "Access denied", target.Body) + require.Contains(t, forbiddenErr.Error(), "403") +} + +// --------------------------------------------------------------------------- +// classifyForbiddenType +// --------------------------------------------------------------------------- + +func 
TestClassifyForbiddenType(t *testing.T) { + tests := []struct { + name string + body string + expected string + }{ + { + name: "VALIDATION_REQUIRED keyword", + body: `{"error":{"message":"VALIDATION_REQUIRED"}}`, + expected: "validation", + }, + { + name: "verify your account", + body: `Please verify your account to continue`, + expected: "validation", + }, + { + name: "contains validation_url field", + body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`, + expected: "validation", + }, + { + name: "terms of service violation", + body: `Your account has been suspended for Terms of Service violation`, + expected: "violation", + }, + { + name: "violation keyword", + body: `Account suspended due to policy violation`, + expected: "violation", + }, + { + name: "generic 403", + body: `Access denied`, + expected: "forbidden", + }, + { + name: "empty body", + body: "", + expected: "forbidden", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyForbiddenType(tt.body) + require.Equal(t, tt.expected, got) + }) + } +} + +// --------------------------------------------------------------------------- +// extractValidationURL +// --------------------------------------------------------------------------- + +func TestExtractValidationURL(t *testing.T) { + tests := []struct { + name string + body string + expected string + }{ + { + name: "structured validation_url", + body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`, + expected: "https://accounts.google.com/verify?token=abc", + }, + { + name: "structured appeal_url", + body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`, + expected: "https://support.google.com/appeal/123", + }, + { + name: "validation_url takes priority over appeal_url", + body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`, + 
expected: "https://v.com", + }, + { + name: "fallback regex with verify keyword", + body: `Please verify your account at https://accounts.google.com/verify`, + expected: "https://accounts.google.com/verify", + }, + { + name: "no URL in generic forbidden", + body: `Access denied`, + expected: "", + }, + { + name: "empty body", + body: "", + expected: "", + }, + { + name: "URL present but no validation keywords", + body: `Error at https://example.com/something`, + expected: "", + }, + { + name: "unicode escaped ampersand", + body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`, + expected: "https://accounts.google.com/verify?a=1&b=2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractValidationURL(tt.body) + require.Equal(t, tt.expected, got) + }) + } +} diff --git a/backend/internal/service/antigravity_quota_scope.go b/backend/internal/service/antigravity_quota_scope.go index e181e7f8..b536d16c 100644 --- a/backend/internal/service/antigravity_quota_scope.go +++ b/backend/internal/service/antigravity_quota_scope.go @@ -32,6 +32,10 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste return false } if a.isModelRateLimitedWithContext(ctx, requestedModel) { + // Antigravity + overages 启用 + 积分未耗尽 → 放行(有积分可用) + if a.Platform == PlatformAntigravity && a.IsOveragesEnabled() && !a.isCreditsExhausted() { + return true + } return false } return true diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index dd8dd83f..df1ce9b9 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -76,10 +76,16 @@ type modelRateLimitCall struct { resetAt time.Time } +type extraUpdateCall struct { + accountID int64 + updates map[string]any +} + type stubAntigravityAccountRepo struct { AccountRepository rateCalls []rateLimitCall modelRateLimitCalls 
[]modelRateLimitCall + extraUpdateCalls []extraUpdateCall } func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { @@ -92,6 +98,11 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i return nil } +func (s *stubAntigravityAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + s.extraUpdateCalls = append(s.extraUpdateCalls, extraUpdateCall{accountID: id, updates: updates}) + return nil +} + func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) { t.Setenv(antigravityForwardBaseURLEnv, "") diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go index 432c80e5..f569219f 100644 --- a/backend/internal/service/antigravity_smart_retry_test.go +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -32,15 +32,23 @@ func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID // mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream type mockSmartRetryUpstream struct { - responses []*http.Response - errors []error - callIdx int - calls []string + responses []*http.Response + errors []error + callIdx int + calls []string + requestBodies [][]byte } func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { idx := m.callIdx m.calls = append(m.calls, req.URL.String()) + if req != nil && req.Body != nil { + body, _ := io.ReadAll(req.Body) + m.requestBodies = append(m.requestBodies, body) + req.Body = io.NopCloser(bytes.NewReader(body)) + } else { + m.requestBodies = append(m.requestBodies, nil) + } m.callIdx++ if idx < len(m.responses) { return m.responses[idx], m.errors[idx] diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index 068d6a08..9cdc49aa 100644 --- 
a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -3,7 +3,6 @@ package service import ( "context" "errors" - "log" "log/slog" "strconv" "strings" @@ -17,15 +16,18 @@ const ( antigravityBackfillCooldown = 5 * time.Minute ) -// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) +// AntigravityTokenCache token cache interface. type AntigravityTokenCache = GeminiTokenCache -// AntigravityTokenProvider 管理 Antigravity 账户的 access_token +// AntigravityTokenProvider manages access_token for antigravity accounts. type AntigravityTokenProvider struct { accountRepo AccountRepository tokenCache AntigravityTokenCache antigravityOAuthService *AntigravityOAuthService - backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time + backfillCooldown sync.Map // key: accountID -> last attempt time + refreshAPI *OAuthRefreshAPI + executor OAuthRefreshExecutor + refreshPolicy ProviderRefreshPolicy } func NewAntigravityTokenProvider( @@ -37,10 +39,22 @@ func NewAntigravityTokenProvider( accountRepo: accountRepo, tokenCache: tokenCache, antigravityOAuthService: antigravityOAuthService, + refreshPolicy: AntigravityProviderRefreshPolicy(), } } -// GetAccessToken 获取有效的 access_token +// SetRefreshAPI injects unified OAuth refresh API and executor. +func (p *AntigravityTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) { + p.refreshAPI = api + p.executor = executor +} + +// SetRefreshPolicy injects caller-side refresh policy. +func (p *AntigravityTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) { + p.refreshPolicy = policy +} + +// GetAccessToken returns a valid access_token. 
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { if account == nil { return "", errors.New("account is nil") @@ -48,7 +62,8 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * if account.Platform != PlatformAntigravity { return "", errors.New("not an antigravity account") } - // upstream 类型:直接从 credentials 读取 api_key,不走 OAuth 刷新流程 + + // upstream accounts use static api_key and never refresh oauth token. if account.Type == AccountTypeUpstream { apiKey := account.GetCredential("api_key") if apiKey == "" { @@ -62,46 +77,38 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * cacheKey := AntigravityTokenCacheKey(account) - // 1. 先尝试缓存 + // 1) Try cache first. if p.tokenCache != nil { if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { return token, nil } } - // 2. 如果即将过期则刷新 + // 2) Refresh if needed (pre-expiry skew). expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew - if needsRefresh && p.tokenCache != nil { + if needsRefresh && p.refreshAPI != nil && p.executor != nil { + result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, antigravityTokenRefreshSkew) + if err != nil { + if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { + return "", err + } + } else if result.LockHeld { + if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil { + if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + // default policy: continue with existing token. + } else { + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") + } + } else if needsRefresh && p.tokenCache != nil { + // Backward-compatible test path when refreshAPI is not injected. 
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) if err == nil && locked { defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() - - // 拿到锁后再次检查缓存(另一个 worker 可能已刷新) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil - } - - // 从数据库获取最新账户信息 - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh - } - expiresAt = account.GetCredentialAsTime("expires_at") - if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew { - if p.antigravityOAuthService == nil { - return "", errors.New("antigravity oauth service not configured") - } - tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account) - if err != nil { - return "", err - } - p.mergeCredentials(account, tokenInfo) - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr) - } - expiresAt = account.GetCredentialAsTime("expires_at") - } } } @@ -110,32 +117,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return "", errors.New("access_token not found in credentials") } - // 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现 - // "Invalid project resource name projects/"。 - // 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。 + // Backfill project_id online when missing, with cooldown to avoid hammering. 
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil { if p.shouldAttemptBackfill(account.ID) { p.markBackfillAttempted(account.ID) if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" { account.Credentials["project_id"] = projectID if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr) + slog.Warn("antigravity_project_id_backfill_persist_failed", + "account_id", account.ID, + "error", updateErr, + ) } } } } - // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) + // 3) Populate cache with TTL. if p.tokenCache != nil { latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) if isStale && latestAccount != nil { - // 版本过时,使用 DB 中的最新 token slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID) accessToken = latestAccount.GetCredential("access_token") if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found after version check") } - // 不写入缓存,让下次请求重新处理 } else { ttl := 30 * time.Minute if expiresAt != nil { @@ -156,18 +162,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return accessToken, nil } -// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段 -func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) { - newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } - account.Credentials = newCredentials -} - -// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试) +// shouldAttemptBackfill checks backfill cooldown. 
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool { if v, ok := p.backfillCooldown.Load(accountID); ok { if lastAttempt, ok := v.(time.Time); ok { diff --git a/backend/internal/service/antigravity_token_refresher.go b/backend/internal/service/antigravity_token_refresher.go index e33f88d0..7ce0ccf0 100644 --- a/backend/internal/service/antigravity_token_refresher.go +++ b/backend/internal/service/antigravity_token_refresher.go @@ -25,6 +25,11 @@ func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthServi } } +// CacheKey 返回用于分布式锁的缓存键 +func (r *AntigravityTokenRefresher) CacheKey(account *Account) string { + return AntigravityTokenCacheKey(account) +} + // CanRefresh 检查是否可以刷新此账户 func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool { return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth @@ -58,11 +63,7 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo) // 合并旧的 credentials,保留新 credentials 中不存在的字段 - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } + newCredentials = MergeCredentials(account.Credentials, newCredentials) // 特殊处理 project_id:如果新值为空但旧值非空,保留旧值 // 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失 diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index 07523597..ec20b0a9 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -14,6 +14,19 @@ const ( StatusAPIKeyExpired = "expired" ) +// Rate limit window durations +const ( + RateLimitWindow5h = 5 * time.Hour + RateLimitWindow1d = 24 * time.Hour + RateLimitWindow7d = 7 * 24 * time.Hour +) + +// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration. 
+// A nil windowStart is treated as expired — no initialized window means any accumulated usage is stale. +func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool { + return windowStart == nil || time.Since(*windowStart) >= duration +} + type APIKey struct { ID int64 UserID int64 @@ -36,12 +49,28 @@ type APIKey struct { Quota float64 // Quota limit in USD (0 = unlimited) QuotaUsed float64 // Used quota amount ExpiresAt *time.Time // Expiration time (nil = never expires) + + // Rate limit fields + RateLimit5h float64 // Rate limit in USD per 5h (0 = unlimited) + RateLimit1d float64 // Rate limit in USD per 1d (0 = unlimited) + RateLimit7d float64 // Rate limit in USD per 7d (0 = unlimited) + Usage5h float64 // Used amount in current 5h window + Usage1d float64 // Used amount in current 1d window + Usage7d float64 // Used amount in current 7d window + Window5hStart *time.Time // Start of current 5h window + Window1dStart *time.Time // Start of current 1d window + Window7dStart *time.Time // Start of current 7d window } func (k *APIKey) IsActive() bool { return k.Status == StatusActive } +// HasRateLimits returns true if any rate limit window is configured +func (k *APIKey) HasRateLimits() bool { + return k.RateLimit5h > 0 || k.RateLimit1d > 0 || k.RateLimit7d > 0 +} + // IsExpired checks if the API key has expired func (k *APIKey) IsExpired() bool { if k.ExpiresAt == nil { @@ -81,3 +110,34 @@ func (k *APIKey) GetDaysUntilExpiry() int { } return int(duration.Hours() / 24) } + +// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired. +func (k *APIKey) EffectiveUsage5h() float64 { + if IsWindowExpired(k.Window5hStart, RateLimitWindow5h) { + return 0 + } + return k.Usage5h +} + +// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired. 
+func (k *APIKey) EffectiveUsage1d() float64 { + if IsWindowExpired(k.Window1dStart, RateLimitWindow1d) { + return 0 + } + return k.Usage1d +} + +// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired. +func (k *APIKey) EffectiveUsage7d() float64 { + if IsWindowExpired(k.Window7dStart, RateLimitWindow7d) { + return 0 + } + return k.Usage7d +} + +// APIKeyListFilters holds optional filtering parameters for listing API keys. +type APIKeyListFilters struct { + Search string + Status string + GroupID *int64 // nil=不筛选, 0=无分组, >0=指定分组 +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 4240be23..e8ad5c9c 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -19,6 +19,11 @@ type APIKeyAuthSnapshot struct { // Expiration field for API Key expiration feature ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires) + + // Rate limit configuration (only limits, not usage - usage read from Redis at check time) + RateLimit5h float64 `json:"rate_limit_5h"` + RateLimit1d float64 `json:"rate_limit_1d"` + RateLimit7d float64 `json:"rate_limit_7d"` } // APIKeyAuthUserSnapshot 用户快照 @@ -60,6 +65,10 @@ type APIKeyAuthGroupSnapshot struct { // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` + + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + DefaultMappedModel string `json:"default_mapped_model,omitempty"` } // APIKeyAuthCacheEntry 缓存条目,支持负缓存 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 30eb8d74..f727ab10 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -209,6 +209,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) 
*APIKeyAuthSnapshot { Quota: apiKey.Quota, QuotaUsed: apiKey.QuotaUsed, ExpiresAt: apiKey.ExpiresAt, + RateLimit5h: apiKey.RateLimit5h, + RateLimit1d: apiKey.RateLimit1d, + RateLimit7d: apiKey.RateLimit7d, User: APIKeyAuthUserSnapshot{ ID: apiKey.User.ID, Status: apiKey.User.Status, @@ -242,6 +245,8 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, MCPXMLInject: apiKey.Group.MCPXMLInject, SupportedModelScopes: apiKey.Group.SupportedModelScopes, + AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, + DefaultMappedModel: apiKey.Group.DefaultMappedModel, } } return snapshot @@ -262,6 +267,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho Quota: snapshot.Quota, QuotaUsed: snapshot.QuotaUsed, ExpiresAt: snapshot.ExpiresAt, + RateLimit5h: snapshot.RateLimit5h, + RateLimit1d: snapshot.RateLimit1d, + RateLimit7d: snapshot.RateLimit7d, User: &User{ ID: snapshot.User.ID, Status: snapshot.User.Status, @@ -296,6 +304,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, MCPXMLInject: snapshot.Group.MCPXMLInject, SupportedModelScopes: snapshot.Group.SupportedModelScopes, + AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, + DefaultMappedModel: snapshot.Group.DefaultMappedModel, } } s.compileAPIKeyIPRules(apiKey) diff --git a/backend/internal/service/api_key_rate_limit_test.go b/backend/internal/service/api_key_rate_limit_test.go new file mode 100644 index 00000000..4058ca4b --- /dev/null +++ b/backend/internal/service/api_key_rate_limit_test.go @@ -0,0 +1,245 @@ +package service + +import ( + "testing" + "time" +) + +func TestIsWindowExpired(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + start *time.Time + duration time.Duration + want bool + }{ + { + name: "nil window start (treated as expired)", + start: nil, + 
duration: RateLimitWindow5h, + want: true, + }, + { + name: "active window (started 1h ago, 5h window)", + start: rateLimitTimePtr(now.Add(-1 * time.Hour)), + duration: RateLimitWindow5h, + want: false, + }, + { + name: "expired window (started 6h ago, 5h window)", + start: rateLimitTimePtr(now.Add(-6 * time.Hour)), + duration: RateLimitWindow5h, + want: true, + }, + { + name: "exactly at boundary (started 5h ago, 5h window)", + start: rateLimitTimePtr(now.Add(-5 * time.Hour)), + duration: RateLimitWindow5h, + want: true, + }, + { + name: "active 1d window (started 12h ago)", + start: rateLimitTimePtr(now.Add(-12 * time.Hour)), + duration: RateLimitWindow1d, + want: false, + }, + { + name: "expired 1d window (started 25h ago)", + start: rateLimitTimePtr(now.Add(-25 * time.Hour)), + duration: RateLimitWindow1d, + want: true, + }, + { + name: "active 7d window (started 3d ago)", + start: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)), + duration: RateLimitWindow7d, + want: false, + }, + { + name: "expired 7d window (started 8d ago)", + start: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)), + duration: RateLimitWindow7d, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsWindowExpired(tt.start, tt.duration) + if got != tt.want { + t.Errorf("IsWindowExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAPIKey_EffectiveUsage(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + key APIKey + want5h float64 + want1d float64 + want7d float64 + }{ + { + name: "all windows active", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)), + }, + want5h: 5.0, + want1d: 10.0, + want7d: 50.0, + }, + { + name: "all windows expired", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + 
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-25 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)), + }, + want5h: 0, + want1d: 0, + want7d: 0, + }, + { + name: "nil window starts return 0 (stale usage reset)", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + Window5hStart: nil, + Window1dStart: nil, + Window7dStart: nil, + }, + want5h: 0, + want1d: 0, + want7d: 0, + }, + { + name: "mixed: 5h expired, 1d active, 7d nil", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)), + Window7dStart: nil, + }, + want5h: 0, + want1d: 10.0, + want7d: 0, + }, + { + name: "zero usage with active windows", + key: APIKey{ + Usage5h: 0, + Usage1d: 0, + Usage7d: 0, + Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)), + }, + want5h: 0, + want1d: 0, + want7d: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.key.EffectiveUsage5h(); got != tt.want5h { + t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h) + } + if got := tt.key.EffectiveUsage1d(); got != tt.want1d { + t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d) + } + if got := tt.key.EffectiveUsage7d(); got != tt.want7d { + t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d) + } + }) + } +} + +func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + data APIKeyRateLimitData + want5h float64 + want1d float64 + want7d float64 + }{ + { + name: "all windows active", + data: APIKeyRateLimitData{ + Usage5h: 3.0, + Usage1d: 8.0, + Usage7d: 40.0, + Window5hStart: rateLimitTimePtr(now.Add(-2 * time.Hour)), + Window1dStart: 
rateLimitTimePtr(now.Add(-10 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-2 * 24 * time.Hour)), + }, + want5h: 3.0, + want1d: 8.0, + want7d: 40.0, + }, + { + name: "all windows expired", + data: APIKeyRateLimitData{ + Usage5h: 3.0, + Usage1d: 8.0, + Usage7d: 40.0, + Window5hStart: rateLimitTimePtr(now.Add(-10 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-48 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-10 * 24 * time.Hour)), + }, + want5h: 0, + want1d: 0, + want7d: 0, + }, + { + name: "nil window starts return 0 (stale usage reset)", + data: APIKeyRateLimitData{ + Usage5h: 3.0, + Usage1d: 8.0, + Usage7d: 40.0, + Window5hStart: nil, + Window1dStart: nil, + Window7dStart: nil, + }, + want5h: 0, + want1d: 0, + want7d: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.data.EffectiveUsage5h(); got != tt.want5h { + t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h) + } + if got := tt.data.EffectiveUsage1d(); got != tt.want1d { + t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d) + } + if got := tt.data.EffectiveUsage7d(); got != tt.want7d { + t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d) + } + }) + } +} + +func rateLimitTimePtr(t time.Time) *time.Time { + return &t +} diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 0d073077..18e9ff7a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "strconv" + "strings" "sync" "time" @@ -30,6 +31,11 @@ var ( ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期") // ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted") ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完") + + // Rate limit errors + ErrAPIKeyRateLimit5hExceeded = 
infraerrors.TooManyRequests("API_KEY_RATE_5H_EXCEEDED", "api key 5小时限额已用完") + ErrAPIKeyRateLimit1dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_1D_EXCEEDED", "api key 日限额已用完") + ErrAPIKeyRateLimit7dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_7D_EXCEEDED", "api key 7天限额已用完") ) const ( @@ -50,7 +56,7 @@ type APIKeyRepository interface { Update(ctx context.Context, key *APIKey) error Delete(ctx context.Context, id int64) error - ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) + ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) CountByUserID(ctx context.Context, userID int64) (int64, error) ExistsByKey(ctx context.Context, key string) (bool, error) @@ -64,6 +70,54 @@ type APIKeyRepository interface { // Quota methods IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error + + // Rate limit methods + IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error + ResetRateLimitWindows(ctx context.Context, id int64) error + GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) +} + +// APIKeyRateLimitData holds rate limit usage and window state for an API key. +type APIKeyRateLimitData struct { + Usage5h float64 + Usage1d float64 + Usage7d float64 + Window5hStart *time.Time + Window1dStart *time.Time + Window7dStart *time.Time +} + +// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired. +func (d *APIKeyRateLimitData) EffectiveUsage5h() float64 { + if IsWindowExpired(d.Window5hStart, RateLimitWindow5h) { + return 0 + } + return d.Usage5h +} + +// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired. 
+func (d *APIKeyRateLimitData) EffectiveUsage1d() float64 { + if IsWindowExpired(d.Window1dStart, RateLimitWindow1d) { + return 0 + } + return d.Usage1d +} + +// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired. +func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 { + if IsWindowExpired(d.Window7dStart, RateLimitWindow7d) { + return 0 + } + return d.Usage7d +} + +// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update. +// It is intentionally small so repositories can return it from a single SQL statement. +type APIKeyQuotaUsageState struct { + QuotaUsed float64 + Quota float64 + Key string + Status string } // APIKeyCache defines cache operations for API key service @@ -102,6 +156,11 @@ type CreateAPIKeyRequest struct { // Quota fields Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires) + + // Rate limit fields (0 = unlimited) + RateLimit5h float64 `json:"rate_limit_5h"` + RateLimit1d float64 `json:"rate_limit_1d"` + RateLimit7d float64 `json:"rate_limit_7d"` } // UpdateAPIKeyRequest 更新API Key请求 @@ -117,22 +176,34 @@ type UpdateAPIKeyRequest struct { ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change) ClearExpiration bool `json:"-"` // Clear expiration (internal use) ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0 + + // Rate limit fields (nil = no change, 0 = unlimited) + RateLimit5h *float64 `json:"rate_limit_5h"` + RateLimit1d *float64 `json:"rate_limit_1d"` + RateLimit7d *float64 `json:"rate_limit_7d"` + ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // Reset all usage counters to 0 } // APIKeyService API Key服务 +// RateLimitCacheInvalidator invalidates rate limit cache entries on manual reset. 
+type RateLimitCacheInvalidator interface { + InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error +} + type APIKeyService struct { - apiKeyRepo APIKeyRepository - userRepo UserRepository - groupRepo GroupRepository - userSubRepo UserSubscriptionRepository - userGroupRateRepo UserGroupRateRepository - cache APIKeyCache - cfg *config.Config - authCacheL1 *ristretto.Cache - authCfg apiKeyAuthCacheConfig - authGroup singleflight.Group - lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time) - lastUsedTouchSF singleflight.Group + apiKeyRepo APIKeyRepository + userRepo UserRepository + groupRepo GroupRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache APIKeyCache + rateLimitCacheInvalid RateLimitCacheInvalidator // optional: invalidate Redis rate limit cache + cfg *config.Config + authCacheL1 *ristretto.Cache + authCfg apiKeyAuthCacheConfig + authGroup singleflight.Group + lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time) + lastUsedTouchSF singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -158,6 +229,12 @@ func NewAPIKeyService( return svc } +// SetRateLimitCacheInvalidator sets the optional rate limit cache invalidator. +// Called after construction (e.g. in wire) to avoid circular dependencies. 
+func (s *APIKeyService) SetRateLimitCacheInvalidator(inv RateLimitCacheInvalidator) { + s.rateLimitCacheInvalid = inv +} + func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) { if apiKey == nil { return @@ -327,6 +404,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK IPBlacklist: req.IPBlacklist, Quota: req.Quota, QuotaUsed: 0, + RateLimit5h: req.RateLimit5h, + RateLimit1d: req.RateLimit1d, + RateLimit7d: req.RateLimit7d, } // Set expiration time if specified @@ -346,8 +426,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK } // List 获取用户的API Key列表 -func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { - keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) +func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, filters) if err != nil { return nil, nil, fmt.Errorf("list api keys: %w", err) } @@ -519,6 +599,26 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req apiKey.IPWhitelist = req.IPWhitelist apiKey.IPBlacklist = req.IPBlacklist + // Update rate limit configuration + if req.RateLimit5h != nil { + apiKey.RateLimit5h = *req.RateLimit5h + } + if req.RateLimit1d != nil { + apiKey.RateLimit1d = *req.RateLimit1d + } + if req.RateLimit7d != nil { + apiKey.RateLimit7d = *req.RateLimit7d + } + resetRateLimit := req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage + if resetRateLimit { + apiKey.Usage5h = 0 + apiKey.Usage1d = 0 + apiKey.Usage7d = 0 + apiKey.Window5hStart = nil + apiKey.Window1dStart = nil + apiKey.Window7dStart = nil + } + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { return nil, fmt.Errorf("update api key: %w", err) } @@ 
-526,6 +626,11 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req s.InvalidateAuthCacheByKey(ctx, apiKey.Key) s.compileAPIKeyIPRules(apiKey) + // Invalidate Redis rate limit cache so reset takes effect immediately + if resetRateLimit && s.rateLimitCacheInvalid != nil { + _ = s.rateLimitCacheInvalid.InvalidateAPIKeyRateLimit(ctx, apiKey.ID) + } + return apiKey, nil } @@ -722,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos return nil } + type quotaStateReader interface { + IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) + } + + if repo, ok := s.apiKeyRepo.(quotaStateReader); ok { + state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost) + if err != nil { + return fmt.Errorf("increment quota used: %w", err) + } + if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" { + s.InvalidateAuthCacheByKey(ctx, state.Key) + } + return nil + } + // Use repository to atomically increment quota_used newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost) if err != nil { @@ -746,3 +866,16 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos return nil } + +// GetRateLimitData returns rate limit usage and window state for an API key. +func (s *APIKeyService) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) { + return s.apiKeyRepo.GetRateLimitData(ctx, id) +} + +// UpdateRateLimitUsage atomically increments rate limit usage counters in the DB. 
+func (s *APIKeyService) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { + if cost <= 0 { + return nil + } + return s.apiKeyRepo.IncrementRateLimitUsage(ctx, apiKeyID, cost) +} diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 2357813b..97b8e229 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -53,7 +53,7 @@ func (s *authRepoStub) Delete(ctx context.Context, id int64) error { panic("unexpected Delete call") } -func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { +func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { panic("unexpected ListByUserID call") } @@ -106,6 +106,15 @@ func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { panic("unexpected UpdateLastUsed call") } +func (s *authRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + panic("unexpected IncrementRateLimitUsage call") +} +func (s *authRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error { + panic("unexpected ResetRateLimitWindows call") +} +func (s *authRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) { + panic("unexpected GetRateLimitData call") +} type authCacheStub struct { getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index 79757808..dfd481e8 100644 --- 
a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -81,7 +81,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error { // 以下是接口要求实现但本测试不关心的方法 -func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { +func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { panic("unexpected ListByUserID call") } @@ -134,6 +134,18 @@ func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt ti return nil } +func (s *apiKeyRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + panic("unexpected IncrementRateLimitUsage call") +} + +func (s *apiKeyRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error { + panic("unexpected ResetRateLimitWindows call") +} + +func (s *apiKeyRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) { + panic("unexpected GetRateLimitData call") +} + // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。 // diff --git a/backend/internal/service/api_key_service_quota_test.go b/backend/internal/service/api_key_service_quota_test.go new file mode 100644 index 00000000..2e2f6f78 --- /dev/null +++ b/backend/internal/service/api_key_service_quota_test.go @@ -0,0 +1,170 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type quotaStateRepoStub struct { + quotaBaseAPIKeyRepoStub + stateCalls int + state *APIKeyQuotaUsageState + stateErr error +} + +func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) { + s.stateCalls++ + 
if s.stateErr != nil { + return nil, s.stateErr + } + if s.state == nil { + return nil, nil + } + out := *s.state + return &out, nil +} + +type quotaStateCacheStub struct { + deleteAuthKeys []string +} + +func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) { + return 0, nil +} + +func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error { + return nil +} + +func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error { + return nil +} + +func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error { + return nil +} + +func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error { + return nil +} + +func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error { + return nil +} + +func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + +func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error { + return nil +} + +func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error { + return nil +} + +type quotaBaseAPIKeyRepoStub struct { + getByIDCalls int +} + +func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error { + panic("unexpected Create call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) { + s.getByIDCalls++ + return nil, nil +} +func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) { + panic("unexpected GetKeyAndOwnerID call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) { + panic("unexpected GetByKey call") +} +func (s 
*quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) { + panic("unexpected GetByKeyForAuth call") +} +func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error { + panic("unexpected Update call") +} +func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error { + panic("unexpected Delete call") +} +func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { + panic("unexpected VerifyOwnership call") +} +func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) { + panic("unexpected CountByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) { + panic("unexpected ExistsByKey call") +} +func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) { + panic("unexpected SearchAPIKeys call") +} +func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { + panic("unexpected ClearGroupIDByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) { + panic("unexpected CountByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) { + panic("unexpected ListKeysByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) { + panic("unexpected ListKeysByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, 
error) { + panic("unexpected IncrementQuotaUsed call") +} +func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error { + panic("unexpected UpdateLastUsed call") +} +func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error { + panic("unexpected IncrementRateLimitUsage call") +} +func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error { + panic("unexpected ResetRateLimitWindows call") +} +func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) { + panic("unexpected GetRateLimitData call") +} + +func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) { + repo := &quotaStateRepoStub{ + state: &APIKeyQuotaUsageState{ + QuotaUsed: 12, + Quota: 10, + Key: "sk-test-quota", + Status: StatusAPIKeyQuotaExhausted, + }, + } + cache := &quotaStateCacheStub{} + svc := &APIKeyService{ + apiKeyRepo: repo, + cache: cache, + } + + err := svc.UpdateQuotaUsed(context.Background(), 101, 2) + require.NoError(t, err) + require.Equal(t, 1, repo.stateCalls) + require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id") + require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys) +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 9df61c44..42b6cf91 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -8,9 +8,11 @@ import ( "encoding/hex" "errors" "fmt" "net/mail" + "strconv" "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -20,23 +22,25 @@ import ( ) var ( - ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") - ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is 
not active") - ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") - ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") - ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") - ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") - ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired") - ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") - ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") - ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token") - ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") - ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") - ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") - ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") - ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") - ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required") - ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code") + ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") + ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") + ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") + ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") + ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") + ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") + 
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired") + ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") + ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") + ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token") + ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") + ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") + ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") + ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed") + ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") + ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required") + ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code") + ErrOAuthInvitationRequired = infraerrors.Forbidden("OAUTH_INVITATION_REQUIRED", "invitation code required to complete oauth registration") ) // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 @@ -56,6 +60,7 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { + entClient *dbent.Client userRepo UserRepository redeemRepo RedeemCodeRepository refreshTokenCache RefreshTokenCache @@ -74,6 +79,7 @@ type DefaultSubscriptionAssigner interface { // NewAuthService 创建认证服务实例 func NewAuthService( + entClient *dbent.Client, userRepo UserRepository, redeemRepo RedeemCodeRepository, refreshTokenCache RefreshTokenCache, @@ -86,6 +92,7 @@ func NewAuthService( defaultSubAssigner DefaultSubscriptionAssigner, ) *AuthService { 
return &AuthService{ + entClient: entClient, userRepo: userRepo, redeemRepo: redeemRepo, refreshTokenCache: refreshTokenCache, @@ -115,6 +122,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw if isReservedEmail(email) { return "", nil, ErrEmailReserved } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return "", nil, err + } // 检查是否需要邀请码 var invitationRedeemCode *RedeemCode @@ -241,6 +251,9 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { if isReservedEmail(email) { return ErrEmailReserved } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return err + } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) @@ -279,6 +292,9 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S if isReservedEmail(email) { return nil, ErrEmailReserved } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return nil, err + } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) @@ -512,9 +528,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return token, user, nil } -// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair -// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token -func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) { +// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。 +// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。 +// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。 +func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) { // 检查 refreshTokenCache 是否可用 if s.refreshTokenCache == nil { return nil, nil, errors.New("refresh token cache not configured") @@ -541,6 +558,22 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx 
context.Context, ema return nil, nil, ErrRegDisabled } + // 检查是否需要邀请码 + var invitationRedeemCode *RedeemCode + if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) { + if invitationCode == "" { + return nil, nil, ErrOAuthInvitationRequired + } + redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) + if err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + return nil, nil, ErrInvitationCodeInvalid + } + invitationRedeemCode = redeemCode + } + randomPassword, err := randomHexString(32) if err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err) @@ -568,20 +601,58 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema Status: StatusActive, } - if err := s.userRepo.Create(ctx, newUser); err != nil { - if errors.Is(err, ErrEmailExists) { - user, err = s.userRepo.GetByEmail(ctx, email) - if err != nil { - logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + if s.entClient != nil && invitationRedeemCode != nil { + tx, err := s.entClient.Tx(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err) + return nil, nil, ErrServiceUnavailable + } + defer func() { _ = tx.Rollback() }() + txCtx := dbent.NewTxContext(ctx, tx) + + if err := s.userRepo.Create(txCtx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) return nil, nil, ErrServiceUnavailable } } else { - logger.LegacyPrintf("service.auth", "[Auth] Database error creating 
oauth user: %v", err) - return nil, nil, ErrServiceUnavailable + if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + if err := tx.Commit(); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err) + return nil, nil, ErrServiceUnavailable + } + user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) } } else { - user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) + if invitationRedeemCode != nil { + if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + } + } } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -607,6 +678,63 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } +// pendingOAuthTokenTTL is the validity period for pending OAuth tokens. +const pendingOAuthTokenTTL = 10 * time.Minute + +// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens. 
+const pendingOAuthPurpose = "pending_oauth_registration" + +type pendingOAuthClaims struct { + Email string `json:"email"` + Username string `json:"username"` + Purpose string `json:"purpose"` + jwt.RegisteredClaims +} + +// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity +// while waiting for the user to supply an invitation code. +func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) { + now := time.Now() + claims := &pendingOAuthClaims{ + Email: email, + Username: username, + Purpose: pendingOAuthPurpose, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(s.cfg.JWT.Secret)) +} + +// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity. +// Returns ErrInvalidToken when the token is invalid or expired. 
+func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) { + if len(tokenStr) > maxTokenLength { + return "", "", ErrInvalidToken + } + parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return []byte(s.cfg.JWT.Secret), nil + }) + if parseErr != nil { + return "", "", ErrInvalidToken + } + claims, ok := token.Claims.(*pendingOAuthClaims) + if !ok || !token.Valid { + return "", "", ErrInvalidToken + } + if claims.Purpose != pendingOAuthPurpose { + return "", "", ErrInvalidToken + } + return claims.Email, claims.Username, nil +} + func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { return @@ -624,6 +752,32 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int } } +func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error { + if s.settingService == nil { + return nil + } + whitelist := s.settingService.GetRegistrationEmailSuffixWhitelist(ctx) + if !IsRegistrationEmailSuffixAllowed(email, whitelist) { + return buildEmailSuffixNotAllowedError(whitelist) + } + return nil +} + +func buildEmailSuffixNotAllowedError(whitelist []string) error { + if len(whitelist) == 0 { + return ErrEmailSuffixNotAllowed + } + + allowed := strings.Join(whitelist, ", ") + return infraerrors.BadRequest( + "EMAIL_SUFFIX_NOT_ALLOWED", + fmt.Sprintf("email suffix is not allowed, allowed suffixes: %s", allowed), + ).WithMetadata(map[string]string{ + "allowed_suffixes": strings.Join(whitelist, ","), + "allowed_suffix_count": strconv.Itoa(len(whitelist)), + }) +} + // ValidateToken 验证JWT token并返回用户声明 func (s 
*AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 @@ -933,6 +1087,12 @@ type TokenPair struct { ExpiresIn int `json:"expires_in"` // Access Token有效期(秒) } +// TokenPairWithUser extends TokenPair with user role for backend mode checks +type TokenPairWithUser struct { + TokenPair + UserRole string +} + // GenerateTokenPair 生成Access Token和Refresh Token对 // familyID: 可选的Token家族ID,用于Token轮转时保持家族关系 func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) { @@ -1014,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami // RefreshTokenPair 使用Refresh Token刷新Token对 // 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效 -func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) { +func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) { // 检查 refreshTokenCache 是否可用 if s.refreshTokenCache == nil { return nil, ErrRefreshTokenInvalid @@ -1079,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) } // 生成新的Token对,保持同一个家族ID - return s.GenerateTokenPair(ctx, user, data.FamilyID) + pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID) + if err != nil { + return nil, err + } + return &TokenPairWithUser{ + TokenPair: *pair, + UserRole: user.Role, + }, nil } // RevokeRefreshToken 撤销单个Refresh Token diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go new file mode 100644 index 00000000..0472e06c --- /dev/null +++ b/backend/internal/service/auth_service_pending_oauth_test.go @@ -0,0 +1,146 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" +) + +func newAuthServiceForPendingOAuthTest() 
*AuthService { + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret-pending-oauth", + ExpireHour: 1, + }, + } + return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) +} + +// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。 +func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + token, err := svc.CreatePendingOAuthToken("user@example.com", "alice") + require.NoError(t, err) + require.NotEmpty(t, token) + + email, username, err := svc.VerifyPendingOAuthToken(token) + require.NoError(t, err) + require.Equal(t, "user@example.com", email) + require.Equal(t, "alice", username) +} + +// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + // 签发一个普通 access token(JWTClaims,无 Purpose 字段) + accessToken, err := svc.GenerateToken(&User{ + ID: 1, + Email: "user@example.com", + Role: RoleUser, + }) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(accessToken) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + now := time.Now() + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: "some_other_purpose", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// 
TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + now := time.Now() + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: "", // 旧 token 无此字段,反序列化后为零值 + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + past := time.Now().Add(-1 * time.Hour) + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: pendingOAuthPurpose, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(past), + IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)), + NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) { + other := NewAuthService(nil, nil, nil, nil, &config.Config{ + JWT: config.JWTConfig{Secret: "other-secret"}, + }, nil, nil, nil, nil, nil, nil) + + token, err := other.CreatePendingOAuthToken("user@example.com", "alice") + require.NoError(t, err) + + svc := 
newAuthServiceForPendingOAuthTest() + _, _, err = svc.VerifyPendingOAuthToken(token) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_TooLong(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + giant := make([]byte, maxTokenLength+1) + for i := range giant { + giant[i] = 'a' + } + _, _, err := svc.VerifyPendingOAuthToken(string(giant)) + require.ErrorIs(t, err, ErrInvalidToken) +} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 1999e759..7b50e90d 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/require" ) @@ -129,6 +130,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E } return NewAuthService( + nil, // entClient repo, nil, // redeemRepo nil, // refreshTokenCache @@ -231,6 +233,51 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) { require.ErrorIs(t, err, ErrEmailReserved) } +func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`, + }, nil) + + _, _, err := service.Register(context.Background(), "user@other.com", "password") + require.ErrorIs(t, err, ErrEmailSuffixNotAllowed) + appErr := infraerrors.FromError(err) + require.Contains(t, appErr.Message, "@example.com") + require.Contains(t, appErr.Message, "@company.com") + require.Equal(t, "EMAIL_SUFFIX_NOT_ALLOWED", appErr.Reason) + require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"]) + require.Equal(t, 
"@example.com,@company.com", appErr.Metadata["allowed_suffixes"]) +} + +func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) { + repo := &userRepoStub{nextID: 8} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`, + }, nil) + + _, user, err := service.Register(context.Background(), "user@example.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, int64(8), user.ID) +} + +func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`, + }, nil) + + err := service.SendVerifyCode(context.Background(), "user@other.com") + require.ErrorIs(t, err, ErrEmailSuffixNotAllowed) + appErr := infraerrors.FromError(err) + require.Contains(t, appErr.Message, "@example.com") + require.Contains(t, appErr.Message, "@company.com") + require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"]) +} + func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} service := newAuthService(repo, map[string]string{ @@ -402,7 +449,7 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { repo := &userRepoStub{nextID: 42} assigner := &defaultSubscriptionAssignerStub{} service := newAuthService(repo, map[string]string{ - SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEnabled: "true", SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, }, nil) service.defaultSubAssigner = assigner diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go index 36cb1e06..477ba1b2 100644 --- 
a/backend/internal/service/auth_service_turnstile_register_test.go +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -43,6 +43,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier turnstileService := NewTurnstileService(settingService, verifier) return NewAuthService( + nil, // entClient &userRepoStub{}, nil, // redeemRepo nil, // refreshTokenCache diff --git a/backend/internal/service/backup_service.go b/backend/internal/service/backup_service.go new file mode 100644 index 00000000..25f1e9a1 --- /dev/null +++ b/backend/internal/service/backup_service.go @@ -0,0 +1,770 @@ +package service + +import ( + "compress/gzip" + "context" + "encoding/json" + "fmt" + "io" + "sort" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/robfig/cron/v3" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +const ( + settingKeyBackupS3Config = "backup_s3_config" + settingKeyBackupSchedule = "backup_schedule" + settingKeyBackupRecords = "backup_records" + + maxBackupRecords = 100 +) + +var ( + ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured") + ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found") + ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress") + ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress") + ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted") + ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted") +) + +// ─── 接口定义 ─── + +// DBDumper abstracts database dump/restore operations +type DBDumper interface { + Dump(ctx context.Context) (io.ReadCloser, error) + 
Restore(ctx context.Context, data io.Reader) error +} + +// BackupObjectStore abstracts object storage for backup files +type BackupObjectStore interface { + Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error) + Download(ctx context.Context, key string) (io.ReadCloser, error) + Delete(ctx context.Context, key string) error + PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) + HeadBucket(ctx context.Context) error +} + +// BackupObjectStoreFactory creates an object store from S3 config +type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) + +// ─── 数据模型 ─── + +// BackupS3Config S3 兼容存储配置(支持 Cloudflare R2) +type BackupS3Config struct { + Endpoint string `json:"endpoint"` // e.g. https://.r2.cloudflarestorage.com + Region string `json:"region"` // R2 用 "auto" + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key,omitempty"` //nolint:revive // field name follows AWS convention + Prefix string `json:"prefix"` // S3 key 前缀,如 "backups/" + ForcePathStyle bool `json:"force_path_style"` +} + +// IsConfigured 检查必要字段是否已配置 +func (c *BackupS3Config) IsConfigured() bool { + return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != "" +} + +// BackupScheduleConfig 定时备份配置 +type BackupScheduleConfig struct { + Enabled bool `json:"enabled"` + CronExpr string `json:"cron_expr"` // cron 表达式,如 "0 2 * * *" 每天凌晨2点 + RetainDays int `json:"retain_days"` // 备份文件过期天数,默认14,0=不自动清理 + RetainCount int `json:"retain_count"` // 最多保留份数,0=不限制 +} + +// BackupRecord 备份记录 +type BackupRecord struct { + ID string `json:"id"` + Status string `json:"status"` // pending, running, completed, failed + BackupType string `json:"backup_type"` // postgres + FileName string `json:"file_name"` + S3Key string `json:"s3_key"` + SizeBytes int64 `json:"size_bytes"` + TriggeredBy string 
`json:"triggered_by"` // manual, scheduled + ErrorMsg string `json:"error_message,omitempty"` + StartedAt string `json:"started_at"` + FinishedAt string `json:"finished_at,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` // 过期时间 +} + +// BackupService 数据库备份恢复服务 +type BackupService struct { + settingRepo SettingRepository + dbCfg *config.DatabaseConfig + encryptor SecretEncryptor + storeFactory BackupObjectStoreFactory + dumper DBDumper + + mu sync.Mutex + store BackupObjectStore + s3Cfg *BackupS3Config + backingUp bool + restoring bool + + recordsMu sync.Mutex // 保护 records 的 load/save 操作 + + cronMu sync.Mutex + cronSched *cron.Cron + cronEntryID cron.EntryID +} + +func NewBackupService( + settingRepo SettingRepository, + cfg *config.Config, + encryptor SecretEncryptor, + storeFactory BackupObjectStoreFactory, + dumper DBDumper, +) *BackupService { + return &BackupService{ + settingRepo: settingRepo, + dbCfg: &cfg.Database, + encryptor: encryptor, + storeFactory: storeFactory, + dumper: dumper, + } +} + +// Start 启动定时备份调度器 +func (s *BackupService) Start() { + s.cronSched = cron.New() + s.cronSched.Start() + + // 加载已有的定时配置 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + schedule, err := s.GetSchedule(ctx) + if err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 加载定时备份配置失败: %v", err) + return + } + if schedule.Enabled && schedule.CronExpr != "" { + if err := s.applyCronSchedule(schedule); err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 应用定时备份配置失败: %v", err) + } + } +} + +// Stop 停止定时备份 +func (s *BackupService) Stop() { + s.cronMu.Lock() + defer s.cronMu.Unlock() + if s.cronSched != nil { + s.cronSched.Stop() + } +} + +// ─── S3 配置管理 ─── + +func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) { + cfg, err := s.loadS3Config(ctx) + if err != nil { + return nil, err + } + if cfg == nil { + return &BackupS3Config{}, nil + } + // 脱敏返回 + cfg.SecretAccessKey = "" 
+ return cfg, nil +} + +func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) { + // 如果没提供 secret,保留原有值 + if cfg.SecretAccessKey == "" { + old, _ := s.loadS3Config(ctx) + if old != nil { + cfg.SecretAccessKey = old.SecretAccessKey + } + } else { + // 加密 SecretAccessKey + encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey) + if err != nil { + return nil, fmt.Errorf("encrypt secret: %w", err) + } + cfg.SecretAccessKey = encrypted + } + + data, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal s3 config: %w", err) + } + if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil { + return nil, fmt.Errorf("save s3 config: %w", err) + } + + // 清除缓存的 S3 客户端 + s.mu.Lock() + s.store = nil + s.s3Cfg = nil + s.mu.Unlock() + + cfg.SecretAccessKey = "" + return &cfg, nil +} + +func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error { + // 如果没提供 secret,用已保存的 + if cfg.SecretAccessKey == "" { + old, _ := s.loadS3Config(ctx) + if old != nil { + cfg.SecretAccessKey = old.SecretAccessKey + } + } + + if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" { + return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required") + } + + store, err := s.storeFactory(ctx, &cfg) + if err != nil { + return err + } + return store.HeadBucket(ctx) +} + +// ─── 定时备份管理 ─── + +func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) { + raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule) + if err != nil || raw == "" { + return &BackupScheduleConfig{}, nil + } + var cfg BackupScheduleConfig + if err := json.Unmarshal([]byte(raw), &cfg); err != nil { + return &BackupScheduleConfig{}, nil + } + return &cfg, nil +} + +func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) { + if cfg.Enabled && cfg.CronExpr == 
"" { + return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled") + } + // 验证 cron 表达式 + if cfg.CronExpr != "" { + parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) + if _, err := parser.Parse(cfg.CronExpr); err != nil { + return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err)) + } + } + + data, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal schedule config: %w", err) + } + if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil { + return nil, fmt.Errorf("save schedule config: %w", err) + } + + // 应用或停止定时任务 + if cfg.Enabled { + if err := s.applyCronSchedule(&cfg); err != nil { + return nil, err + } + } else { + s.removeCronSchedule() + } + + return &cfg, nil +} + +func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error { + s.cronMu.Lock() + defer s.cronMu.Unlock() + + if s.cronSched == nil { + return fmt.Errorf("cron scheduler not initialized") + } + + // 移除旧任务 + if s.cronEntryID != 0 { + s.cronSched.Remove(s.cronEntryID) + s.cronEntryID = 0 + } + + entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() { + s.runScheduledBackup() + }) + if err != nil { + return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err)) + } + s.cronEntryID = entryID + logger.LegacyPrintf("service.backup", "[Backup] 定时备份已启用: %s", cfg.CronExpr) + return nil +} + +func (s *BackupService) removeCronSchedule() { + s.cronMu.Lock() + defer s.cronMu.Unlock() + if s.cronSched != nil && s.cronEntryID != 0 { + s.cronSched.Remove(s.cronEntryID) + s.cronEntryID = 0 + logger.LegacyPrintf("service.backup", "[Backup] 定时备份已停用") + } +} + +func (s *BackupService) runScheduledBackup() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + + // 读取定时备份配置中的过期天数 + schedule, _ := s.GetSchedule(ctx) + expireDays := 14 // 默认14天过期 + if 
schedule != nil && schedule.RetainDays > 0 { + expireDays = schedule.RetainDays + } + + logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays) + record, err := s.CreateBackup(ctx, "scheduled", expireDays) + if err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err) + return + } + logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes) + + // 清理过期备份(复用已加载的 schedule) + if schedule == nil { + return + } + if err := s.cleanupOldBackups(ctx, schedule); err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 清理过期备份失败: %v", err) + } +} + +// ─── 备份/恢复核心 ─── + +// CreateBackup 创建全量数据库备份并上传到 S3(流式处理) +// expireDays: 备份过期天数,0=永不过期,默认14天 +func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) { + s.mu.Lock() + if s.backingUp { + s.mu.Unlock() + return nil, ErrBackupInProgress + } + s.backingUp = true + s.mu.Unlock() + defer func() { + s.mu.Lock() + s.backingUp = false + s.mu.Unlock() + }() + + s3Cfg, err := s.loadS3Config(ctx) + if err != nil { + return nil, err + } + if s3Cfg == nil || !s3Cfg.IsConfigured() { + return nil, ErrBackupS3NotConfigured + } + + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) + if err != nil { + return nil, fmt.Errorf("init object store: %w", err) + } + + now := time.Now() + backupID := uuid.New().String()[:8] + fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405")) + s3Key := s.buildS3Key(s3Cfg, fileName) + + var expiresAt string + if expireDays > 0 { + expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339) + } + + record := &BackupRecord{ + ID: backupID, + Status: "running", + BackupType: "postgres", + FileName: fileName, + S3Key: s3Key, + TriggeredBy: triggeredBy, + StartedAt: now.Format(time.RFC3339), + ExpiresAt: expiresAt, + } + + // 流式执行: pg_dump -> gzip -> S3 upload + dumpReader, err := s.dumper.Dump(ctx) + if err != nil { + 
record.Status = "failed" + record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err) + record.FinishedAt = time.Now().Format(time.RFC3339) + _ = s.saveRecord(ctx, record) + return record, fmt.Errorf("pg_dump: %w", err) + } + + // 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传 + pr, pw := io.Pipe() + var gzipErr error + go func() { + gzWriter := gzip.NewWriter(pw) + _, gzipErr = io.Copy(gzWriter, dumpReader) + if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil { + gzipErr = closeErr + } + if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil { + gzipErr = closeErr + } + if gzipErr != nil { + _ = pw.CloseWithError(gzipErr) + } else { + _ = pw.Close() + } + }() + + contentType := "application/gzip" + sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType) + if err != nil { + record.Status = "failed" + errMsg := fmt.Sprintf("S3 upload failed: %v", err) + if gzipErr != nil { + errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr) + } + record.ErrorMsg = errMsg + record.FinishedAt = time.Now().Format(time.RFC3339) + _ = s.saveRecord(ctx, record) + return record, fmt.Errorf("backup upload: %w", err) + } + + record.SizeBytes = sizeBytes + record.Status = "completed" + record.FinishedAt = time.Now().Format(time.RFC3339) + if err := s.saveRecord(ctx, record); err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err) + } + + return record, nil +} + +// RestoreBackup 从 S3 下载备份并流式恢复到数据库 +func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error { + s.mu.Lock() + if s.restoring { + s.mu.Unlock() + return ErrRestoreInProgress + } + s.restoring = true + s.mu.Unlock() + defer func() { + s.mu.Lock() + s.restoring = false + s.mu.Unlock() + }() + + record, err := s.GetBackupRecord(ctx, backupID) + if err != nil { + return err + } + if record.Status != "completed" { + return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup") + } + + s3Cfg, err := s.loadS3Config(ctx) + if 
err != nil { + return err + } + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) + if err != nil { + return fmt.Errorf("init object store: %w", err) + } + + // 从 S3 流式下载 + body, err := objectStore.Download(ctx, record.S3Key) + if err != nil { + return fmt.Errorf("S3 download failed: %w", err) + } + defer func() { _ = body.Close() }() + + // 流式解压 gzip -> psql(不将全部数据加载到内存) + gzReader, err := gzip.NewReader(body) + if err != nil { + return fmt.Errorf("gzip reader: %w", err) + } + defer func() { _ = gzReader.Close() }() + + // 流式恢复 + if err := s.dumper.Restore(ctx, gzReader); err != nil { + return fmt.Errorf("pg restore: %w", err) + } + + return nil +} + +// ─── 备份记录管理 ─── + +func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) { + records, err := s.loadRecords(ctx) + if err != nil { + return nil, err + } + // 倒序返回(最新在前) + sort.Slice(records, func(i, j int) bool { + return records[i].StartedAt > records[j].StartedAt + }) + return records, nil +} + +func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) { + records, err := s.loadRecords(ctx) + if err != nil { + return nil, err + } + for i := range records { + if records[i].ID == backupID { + return &records[i], nil + } + } + return nil, ErrBackupNotFound +} + +func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error { + s.recordsMu.Lock() + defer s.recordsMu.Unlock() + + records, err := s.loadRecordsLocked(ctx) + if err != nil { + return err + } + + var found *BackupRecord + var remaining []BackupRecord + for i := range records { + if records[i].ID == backupID { + found = &records[i] + } else { + remaining = append(remaining, records[i]) + } + } + if found == nil { + return ErrBackupNotFound + } + + // 从 S3 删除 + if found.S3Key != "" && found.Status == "completed" { + s3Cfg, err := s.loadS3Config(ctx) + if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() { + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) + if err == 
nil { + _ = objectStore.Delete(ctx, found.S3Key) + } + } + } + + return s.saveRecordsLocked(ctx, remaining) +} + +// GetBackupDownloadURL 获取备份文件预签名下载 URL +func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) { + record, err := s.GetBackupRecord(ctx, backupID) + if err != nil { + return "", err + } + if record.Status != "completed" { + return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed") + } + + s3Cfg, err := s.loadS3Config(ctx) + if err != nil { + return "", err + } + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) + if err != nil { + return "", err + } + + url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour) + if err != nil { + return "", fmt.Errorf("presign url: %w", err) + } + return url, nil +} + +// ─── 内部方法 ─── + +func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) { + raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config) + if err != nil || raw == "" { + return nil, nil //nolint:nilnil // no config is a valid state + } + var cfg BackupS3Config + if err := json.Unmarshal([]byte(raw), &cfg); err != nil { + return nil, ErrBackupS3ConfigCorrupt + } + // 解密 SecretAccessKey + if cfg.SecretAccessKey != "" { + decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey) + if err != nil { + // 兼容未加密的旧数据:如果解密失败,保持原值 + logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err) + } else { + cfg.SecretAccessKey = decrypted + } + } + return &cfg, nil +} + +func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.store != nil && s.s3Cfg != nil { + return s.store, nil + } + + if cfg == nil { + return nil, ErrBackupS3NotConfigured + } + + store, err := s.storeFactory(ctx, cfg) + if err != nil { + return nil, err + } + s.store = store + s.s3Cfg = cfg + return store, nil +} + +func (s *BackupService) 
buildS3Key(cfg *BackupS3Config, fileName string) string { + prefix := strings.TrimRight(cfg.Prefix, "/") + if prefix == "" { + prefix = "backups" + } + return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName) +} + +// loadRecords 加载备份记录,区分"无数据"和"数据损坏" +func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) { + s.recordsMu.Lock() + defer s.recordsMu.Unlock() + return s.loadRecordsLocked(ctx) +} + +// loadRecordsLocked 在已持有 recordsMu 锁的情况下加载记录 +func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) { + raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords) + if err != nil || raw == "" { + return nil, nil //nolint:nilnil // no records is a valid state + } + var records []BackupRecord + if err := json.Unmarshal([]byte(raw), &records); err != nil { + return nil, ErrBackupRecordsCorrupt + } + return records, nil +} + +// saveRecordsLocked 在已持有 recordsMu 锁的情况下保存记录 +func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error { + data, err := json.Marshal(records) + if err != nil { + return err + } + return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data)) +} + +// saveRecord 保存单条记录(带互斥锁保护) +func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error { + s.recordsMu.Lock() + defer s.recordsMu.Unlock() + + records, _ := s.loadRecordsLocked(ctx) + + // 更新已有记录或追加 + found := false + for i := range records { + if records[i].ID == record.ID { + records[i] = *record + found = true + break + } + } + if !found { + records = append(records, *record) + } + + // 限制记录数量 + if len(records) > maxBackupRecords { + records = records[len(records)-maxBackupRecords:] + } + + return s.saveRecordsLocked(ctx, records) +} + +func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error { + if schedule == nil { + return nil + } + + s.recordsMu.Lock() + defer s.recordsMu.Unlock() + + records, err := 
s.loadRecordsLocked(ctx) + if err != nil { + return err + } + + // 按时间倒序 + sort.Slice(records, func(i, j int) bool { + return records[i].StartedAt > records[j].StartedAt + }) + + var toDelete []BackupRecord + var toKeep []BackupRecord + + for i, r := range records { + shouldDelete := false + + // 按保留份数清理 + if schedule.RetainCount > 0 && i >= schedule.RetainCount { + shouldDelete = true + } + + // 按保留天数清理 + if schedule.RetainDays > 0 && r.StartedAt != "" { + startedAt, err := time.Parse(time.RFC3339, r.StartedAt) + if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour { + shouldDelete = true + } + } + + if shouldDelete && r.Status == "completed" { + toDelete = append(toDelete, r) + } else { + toKeep = append(toKeep, r) + } + } + + // 删除 S3 上的文件 + for _, r := range toDelete { + if r.S3Key != "" { + _ = s.deleteS3Object(ctx, r.S3Key) + } + } + + if len(toDelete) > 0 { + logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete)) + return s.saveRecordsLocked(ctx, toKeep) + } + return nil +} + +func (s *BackupService) deleteS3Object(ctx context.Context, key string) error { + s3Cfg, err := s.loadS3Config(ctx) + if err != nil || s3Cfg == nil { + return nil + } + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) + if err != nil { + return err + } + return objectStore.Delete(ctx, key) +} diff --git a/backend/internal/service/backup_service_test.go b/backend/internal/service/backup_service_test.go new file mode 100644 index 00000000..e752997c --- /dev/null +++ b/backend/internal/service/backup_service_test.go @@ -0,0 +1,528 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +// ─── Mocks ─── + +type mockSettingRepo struct { + mu sync.Mutex + data map[string]string +} + +func newMockSettingRepo() *mockSettingRepo { + return 
&mockSettingRepo{data: make(map[string]string)} +} + +func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) { + m.mu.Lock() + defer m.mu.Unlock() + v, ok := m.data[key] + if !ok { + return nil, ErrSettingNotFound + } + return &Setting{Key: key, Value: v}, nil +} + +func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() + v, ok := m.data[key] + if !ok { + return "", nil + } + return v, nil +} + +func (m *mockSettingRepo) Set(_ context.Context, key, value string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.data[key] = value + return nil +} + +func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + m.mu.Lock() + defer m.mu.Unlock() + result := make(map[string]string) + for _, k := range keys { + if v, ok := m.data[k]; ok { + result[k] = v + } + } + return result, nil +} + +func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error { + m.mu.Lock() + defer m.mu.Unlock() + for k, v := range settings { + m.data[k] = v + } + return nil +} + +func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) { + m.mu.Lock() + defer m.mu.Unlock() + result := make(map[string]string, len(m.data)) + for k, v := range m.data { + result[k] = v + } + return result, nil +} + +func (m *mockSettingRepo) Delete(_ context.Context, key string) error { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.data, key) + return nil +} + +// plainEncryptor 仅做 base64-like 包装,用于测试 +type plainEncryptor struct{} + +func (e *plainEncryptor) Encrypt(plaintext string) (string, error) { + return "ENC:" + plaintext, nil +} + +func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) { + if strings.HasPrefix(ciphertext, "ENC:") { + return strings.TrimPrefix(ciphertext, "ENC:"), nil + } + return ciphertext, fmt.Errorf("not encrypted") +} + +type mockDumper struct { + dumpData []byte + dumpErr error + restored []byte + 
restErr error +} + +func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) { + if m.dumpErr != nil { + return nil, m.dumpErr + } + return io.NopCloser(bytes.NewReader(m.dumpData)), nil +} + +func (m *mockDumper) Restore(_ context.Context, data io.Reader) error { + if m.restErr != nil { + return m.restErr + } + d, err := io.ReadAll(data) + if err != nil { + return err + } + m.restored = d + return nil +} + +type mockObjectStore struct { + objects map[string][]byte + mu sync.Mutex +} + +func newMockObjectStore() *mockObjectStore { + return &mockObjectStore{objects: make(map[string][]byte)} +} + +func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) { + data, err := io.ReadAll(body) + if err != nil { + return 0, err + } + m.mu.Lock() + m.objects[key] = data + m.mu.Unlock() + return int64(len(data)), nil +} + +func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) { + m.mu.Lock() + data, ok := m.objects[key] + m.mu.Unlock() + if !ok { + return nil, fmt.Errorf("not found: %s", key) + } + return io.NopCloser(bytes.NewReader(data)), nil +} + +func (m *mockObjectStore) Delete(_ context.Context, key string) error { + m.mu.Lock() + delete(m.objects, key) + m.mu.Unlock() + return nil +} + +func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) { + return "https://presigned.example.com/" + key, nil +} + +func (m *mockObjectStore) HeadBucket(_ context.Context) error { + return nil +} + +func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService { + cfg := &config.Config{ + Database: config.DatabaseConfig{ + Host: "localhost", + Port: 5432, + User: "test", + DBName: "testdb", + }, + } + factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) { + return store, nil + } + return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper) +} + +func seedS3Config(t 
*testing.T, repo *mockSettingRepo) { + t.Helper() + cfg := BackupS3Config{ + Bucket: "test-bucket", + AccessKeyID: "AKID", + SecretAccessKey: "ENC:secret123", + Prefix: "backups", + } + data, _ := json.Marshal(cfg) + require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data))) +} + +// ─── Tests ─── + +func TestBackupService_S3ConfigEncryption(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + // 保存配置 -> SecretAccessKey 应被加密 + _, err := svc.UpdateS3Config(context.Background(), BackupS3Config{ + Bucket: "my-bucket", + AccessKeyID: "AKID", + SecretAccessKey: "my-secret", + Prefix: "backups", + }) + require.NoError(t, err) + + // 直接读取数据库中存储的值,应该是加密后的 + raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config) + var stored BackupS3Config + require.NoError(t, json.Unmarshal([]byte(raw), &stored)) + require.Equal(t, "ENC:my-secret", stored.SecretAccessKey) + + // 通过 GetS3Config 获取应该脱敏 + cfg, err := svc.GetS3Config(context.Background()) + require.NoError(t, err) + require.Empty(t, cfg.SecretAccessKey) + require.Equal(t, "my-bucket", cfg.Bucket) + + // loadS3Config 内部应解密 + internal, err := svc.loadS3Config(context.Background()) + require.NoError(t, err) + require.Equal(t, "my-secret", internal.SecretAccessKey) +} + +func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + // 先保存一个有 secret 的配置 + _, err := svc.UpdateS3Config(context.Background(), BackupS3Config{ + Bucket: "my-bucket", + AccessKeyID: "AKID", + SecretAccessKey: "original-secret", + }) + require.NoError(t, err) + + // 再更新时不提供 secret,应保留原值 + _, err = svc.UpdateS3Config(context.Background(), BackupS3Config{ + Bucket: "my-bucket", + AccessKeyID: "AKID-NEW", + }) + require.NoError(t, err) + + internal, err := svc.loadS3Config(context.Background()) + require.NoError(t, err) + 
require.Equal(t, "original-secret", internal.SecretAccessKey) + require.Equal(t, "AKID-NEW", internal.AccessKeyID) +} + +func TestBackupService_SaveRecordConcurrency(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + var wg sync.WaitGroup + n := 20 + wg.Add(n) + for i := 0; i < n; i++ { + go func(idx int) { + defer wg.Done() + record := &BackupRecord{ + ID: fmt.Sprintf("rec-%d", idx), + Status: "completed", + StartedAt: time.Now().Format(time.RFC3339), + } + _ = svc.saveRecord(context.Background(), record) + }(i) + } + wg.Wait() + + records, err := svc.loadRecords(context.Background()) + require.NoError(t, err) + require.Len(t, records, n) +} + +func TestBackupService_LoadRecords_Empty(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + records, err := svc.loadRecords(context.Background()) + require.NoError(t, err) + require.Nil(t, records) // 无数据时返回 nil +} + +func TestBackupService_LoadRecords_Corrupted(t *testing.T) { + repo := newMockSettingRepo() + _ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{") + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + records, err := svc.loadRecords(context.Background()) + require.Error(t, err) // 损坏数据应返回错误 + require.Nil(t, records) +} + +func TestBackupService_CreateBackup_Streaming(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n" + dumper := &mockDumper{dumpData: []byte(dumpContent)} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + record, err := svc.CreateBackup(context.Background(), "manual", 14) + require.NoError(t, err) + require.Equal(t, "completed", record.Status) + require.Greater(t, record.SizeBytes, int64(0)) + require.NotEmpty(t, record.S3Key) + + // 验证 S3 上确实有文件 + store.mu.Lock() + 
require.Len(t, store.objects, 1) + store.mu.Unlock() +} + +func TestBackupService_CreateBackup_DumpFailure(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + record, err := svc.CreateBackup(context.Background(), "manual", 14) + require.Error(t, err) + require.Equal(t, "failed", record.Status) + require.Contains(t, record.ErrorMsg, "pg_dump") +} + +func TestBackupService_CreateBackup_NoS3Config(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + _, err := svc.CreateBackup(context.Background(), "manual", 14) + require.ErrorIs(t, err, ErrBackupS3NotConfigured) +} + +func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + // 使用一个慢速 dumper 来模拟正在进行的备份 + dumper := &mockDumper{dumpData: []byte("data")} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + // 手动设置 backingUp 标志 + svc.mu.Lock() + svc.backingUp = true + svc.mu.Unlock() + + _, err := svc.CreateBackup(context.Background(), "manual", 14) + require.ErrorIs(t, err, ErrBackupInProgress) +} + +func TestBackupService_RestoreBackup_Streaming(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n" + dumper := &mockDumper{dumpData: []byte(dumpContent)} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + // 先创建一个备份 + record, err := svc.CreateBackup(context.Background(), "manual", 14) + require.NoError(t, err) + + // 恢复 + err = svc.RestoreBackup(context.Background(), record.ID) + require.NoError(t, err) + + // 验证 psql 收到的数据是否与原始 dump 内容一致 + require.Equal(t, dumpContent, string(dumper.restored)) +} + +func TestBackupService_RestoreBackup_NotCompleted(t 
*testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + // 手动插入一条 failed 记录 + _ = svc.saveRecord(context.Background(), &BackupRecord{ + ID: "fail-1", + Status: "failed", + }) + + err := svc.RestoreBackup(context.Background(), "fail-1") + require.Error(t, err) +} + +func TestBackupService_DeleteBackup(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumpContent := "data" + dumper := &mockDumper{dumpData: []byte(dumpContent)} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + record, err := svc.CreateBackup(context.Background(), "manual", 14) + require.NoError(t, err) + + // S3 中应有文件 + store.mu.Lock() + require.Len(t, store.objects, 1) + store.mu.Unlock() + + // 删除 + err = svc.DeleteBackup(context.Background(), record.ID) + require.NoError(t, err) + + // S3 中文件应被删除 + store.mu.Lock() + require.Len(t, store.objects, 0) + store.mu.Unlock() + + // 记录应不存在 + _, err = svc.GetBackupRecord(context.Background(), record.ID) + require.ErrorIs(t, err, ErrBackupNotFound) +} + +func TestBackupService_GetDownloadURL(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumper := &mockDumper{dumpData: []byte("data")} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + record, err := svc.CreateBackup(context.Background(), "manual", 14) + require.NoError(t, err) + + url, err := svc.GetBackupDownloadURL(context.Background(), record.ID) + require.NoError(t, err) + require.Contains(t, url, "https://presigned.example.com/") +} + +func TestBackupService_ListBackups_Sorted(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + now := time.Now() + for i := 0; i < 3; i++ { + _ = svc.saveRecord(context.Background(), &BackupRecord{ + ID: fmt.Sprintf("rec-%d", i), + Status: "completed", + StartedAt: now.Add(time.Duration(i) 
* time.Hour).Format(time.RFC3339), + }) + } + + records, err := svc.ListBackups(context.Background()) + require.NoError(t, err) + require.Len(t, records, 3) + // 最新在前 + require.Equal(t, "rec-2", records[0].ID) + require.Equal(t, "rec-0", records[2].ID) +} + +func TestBackupService_TestS3Connection(t *testing.T) { + repo := newMockSettingRepo() + store := newMockObjectStore() + svc := newTestBackupService(repo, &mockDumper{}, store) + + err := svc.TestS3Connection(context.Background(), BackupS3Config{ + Bucket: "test", + AccessKeyID: "ak", + SecretAccessKey: "sk", + }) + require.NoError(t, err) +} + +func TestBackupService_TestS3Connection_Incomplete(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + err := svc.TestS3Connection(context.Background(), BackupS3Config{ + Bucket: "test", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "incomplete") +} + +func TestBackupService_Schedule_CronValidation(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + svc.cronSched = nil // 未初始化 cron + + // 启用但 cron 为空 + _, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{ + Enabled: true, + CronExpr: "", + }) + require.Error(t, err) + + // 无效的 cron 表达式 + _, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{ + Enabled: true, + CronExpr: "invalid", + }) + require.Error(t, err) +} + +func TestBackupService_LoadS3Config_Corrupted(t *testing.T) { + repo := newMockSettingRepo() + _ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!") + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + cfg, err := svc.loadS3Config(context.Background()) + require.Error(t, err) + require.Nil(t, cfg) +} diff --git a/backend/internal/service/bedrock_request.go b/backend/internal/service/bedrock_request.go new file mode 100644 index 00000000..2160c13c --- /dev/null +++ 
b/backend/internal/service/bedrock_request.go @@ -0,0 +1,607 @@ +package service + +import ( + "encoding/json" + "fmt" + "net/url" + "regexp" + "strconv" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/domain" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const defaultBedrockRegion = "us-east-1" + +var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."} + +// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀 +// 参考: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html +func BedrockCrossRegionPrefix(region string) string { + switch { + case strings.HasPrefix(region, "us-gov"): + return "us-gov" // GovCloud 使用独立的 us-gov 前缀 + case strings.HasPrefix(region, "us-"): + return "us" + case strings.HasPrefix(region, "eu-"): + return "eu" + case region == "ap-northeast-1": + return "jp" // 日本区域使用独立的 jp 前缀(AWS 官方定义) + case region == "ap-southeast-2": + return "au" // 澳大利亚区域使用独立的 au 前缀(AWS 官方定义) + case strings.HasPrefix(region, "ap-"): + return "apac" // 其余亚太区域使用通用 apac 前缀 + case strings.HasPrefix(region, "ca-"): + return "us" // 加拿大区域使用 us 前缀的跨区域推理 + case strings.HasPrefix(region, "sa-"): + return "us" // 南美区域使用 us 前缀的跨区域推理 + default: + return "us" + } +} + +// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀 +// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1" +// 特殊值 region="global" 强制使用 global. 前缀 +func AdjustBedrockModelRegionPrefix(modelID, region string) string { + var targetPrefix string + if region == "global" { + targetPrefix = "global" + } else { + targetPrefix = BedrockCrossRegionPrefix(region) + } + + for _, p := range bedrockCrossRegionPrefixes { + if strings.HasPrefix(modelID, p) { + if p == targetPrefix+"." { + return modelID // 前缀已匹配,无需替换 + } + return targetPrefix + "." 
+ modelID[len(p):] + } + } + + // 模型 ID 没有已知区域前缀(如 "anthropic.claude-..."),不做修改 + return modelID +} + +func bedrockRuntimeRegion(account *Account) string { + if account == nil { + return defaultBedrockRegion + } + if region := account.GetCredential("aws_region"); region != "" { + return region + } + return defaultBedrockRegion +} + +func shouldForceBedrockGlobal(account *Account) bool { + return account != nil && account.GetCredential("aws_force_global") == "true" +} + +func isRegionalBedrockModelID(modelID string) bool { + for _, prefix := range bedrockCrossRegionPrefixes { + if strings.HasPrefix(modelID, prefix) { + return true + } + } + return false +} + +func isLikelyBedrockModelID(modelID string) bool { + lower := strings.ToLower(strings.TrimSpace(modelID)) + if lower == "" { + return false + } + if strings.HasPrefix(lower, "arn:") { + return true + } + for _, prefix := range []string{ + "anthropic.", + "amazon.", + "meta.", + "mistral.", + "cohere.", + "ai21.", + "deepseek.", + "stability.", + "writer.", + "nova.", + } { + if strings.HasPrefix(lower, prefix) { + return true + } + } + return isRegionalBedrockModelID(lower) +} + +func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return "", false, false + } + if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists { + return mapped, true, true + } + if isRegionalBedrockModelID(modelID) { + return modelID, true, true + } + if isLikelyBedrockModelID(modelID) { + return modelID, false, true + } + return "", false, false +} + +// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID. +// It applies account model_mapping first, then default Bedrock aliases, and finally +// adjusts Anthropic cross-region prefixes to match the account region. 
+func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) { + if account == nil { + return "", false + } + + mappedModel := account.GetMappedModel(requestedModel) + modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel) + if !ok { + return "", false + } + if shouldAdjustRegion { + targetRegion := bedrockRuntimeRegion(account) + if shouldForceBedrockGlobal(account) { + targetRegion = "global" + } + modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion) + } + return modelID, true +} + +// BuildBedrockURL 构建 Bedrock InvokeModel 的 URL +// stream=true 时使用 invoke-with-response-stream 端点 +// modelID 中的特殊字符会被 URL 编码(与 litellm 的 urllib.parse.quote(safe="") 对齐) +func BuildBedrockURL(region, modelID string, stream bool) string { + if region == "" { + region = defaultBedrockRegion + } + encodedModelID := url.PathEscape(modelID) + // url.PathEscape 不编码冒号(RFC 允许 path 中出现 ":"), + // 但 AWS Bedrock 期望模型 ID 中的冒号被编码为 %3A + encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A") + if stream { + return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID) + } + return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID) +} + +// PrepareBedrockRequestBody 处理请求体以适配 Bedrock API +// 1. 注入 anthropic_version +// 2. 注入 anthropic_beta(从客户端 anthropic-beta 头解析) +// 3. 移除 Bedrock 不支持的字段(model, stream, output_format, output_config) +// 4. 移除工具定义中的 custom 字段(Claude Code 会发送 custom: {defer_loading: true}) +// 5. 清理 cache_control 中 Bedrock 不支持的字段(scope, ttl) +func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) { + betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID) + return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens) +} + +// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens. 
+func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) { + var err error + + // 注入 anthropic_version(Bedrock 要求) + body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31") + if err != nil { + return nil, fmt.Errorf("inject anthropic_version: %w", err) + } + + // 注入 anthropic_beta(Bedrock Invoke 通过请求体传递 beta 头,而非 HTTP 头) + // 1. 从客户端 anthropic-beta header 解析 + // 2. 根据请求体内容自动补齐必要的 beta token + // 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock() + if len(betaTokens) > 0 { + body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens) + if err != nil { + return nil, fmt.Errorf("inject anthropic_beta: %w", err) + } + } + + // 移除 model 字段(Bedrock 通过 URL 指定模型) + body, err = sjson.DeleteBytes(body, "model") + if err != nil { + return nil, fmt.Errorf("remove model field: %w", err) + } + + // 移除 stream 字段(Bedrock 通过不同端点控制流式,不接受请求体中的 stream 字段) + body, err = sjson.DeleteBytes(body, "stream") + if err != nil { + return nil, fmt.Errorf("remove stream field: %w", err) + } + + // 转换 output_format(Bedrock Invoke 不支持此字段,但可将 schema 内联到最后一条 user message) + // 参考 litellm: _convert_output_format_to_inline_schema() + body = convertOutputFormatToInlineSchema(body) + + // 移除 output_config 字段(Bedrock Invoke 不支持) + body, err = sjson.DeleteBytes(body, "output_config") + if err != nil { + return nil, fmt.Errorf("remove output_config field: %w", err) + } + + // 移除工具定义中的 custom 字段 + // Claude Code (v2.1.69+) 在 tool 定义中发送 custom: {defer_loading: true}, + // Anthropic API 接受但 Bedrock 会拒绝并报 "Extra inputs are not permitted" + body = removeCustomFieldFromTools(body) + + // 清理 cache_control 中 Bedrock 不支持的字段 + body = sanitizeBedrockCacheControl(body, modelID) + + return body, nil +} + +// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering. 
+func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string { + betaTokens := parseAnthropicBetaHeader(betaHeader) + betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID) + return filterBedrockBetaTokens(betaTokens) +} + +// convertOutputFormatToInlineSchema 将 output_format 中的 JSON schema 内联到最后一条 user message +// Bedrock Invoke 不支持 output_format 参数,litellm 的做法是将 schema 追加到用户消息中 +// 参考: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema() +func convertOutputFormatToInlineSchema(body []byte) []byte { + outputFormat := gjson.GetBytes(body, "output_format") + if !outputFormat.Exists() || !outputFormat.IsObject() { + return body + } + + // 先从请求体中移除 output_format + body, _ = sjson.DeleteBytes(body, "output_format") + + schema := outputFormat.Get("schema") + if !schema.Exists() { + return body + } + + // 找到最后一条 user message + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return body + } + msgArr := messages.Array() + lastUserIdx := -1 + for i := len(msgArr) - 1; i >= 0; i-- { + if msgArr[i].Get("role").String() == "user" { + lastUserIdx = i + break + } + } + if lastUserIdx < 0 { + return body + } + + // 将 schema 序列化为 JSON 文本追加到该 message 的 content 数组 + schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw)) + if err != nil { + return body + } + + content := msgArr[lastUserIdx].Get("content") + basePath := fmt.Sprintf("messages.%d.content", lastUserIdx) + + if content.IsArray() { + // 追加一个 text block 到 content 数组末尾 + idx := len(content.Array()) + body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text") + body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON)) + } else if content.Type == gjson.String { + // content 是纯字符串,转换为数组格式 + originalText := content.String() + body, _ = sjson.SetBytes(body, basePath, []map[string]string{ + {"type": "text", "text": originalText}, + {"type": "text", 
"text": string(schemaJSON)}, + }) + } + + return body +} + +// removeCustomFieldFromTools 移除 tools 数组中每个工具定义的 custom 字段 +func removeCustomFieldFromTools(body []byte) []byte { + tools := gjson.GetBytes(body, "tools") + if !tools.Exists() || !tools.IsArray() { + return body + } + var err error + for i := range tools.Array() { + body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i)) + if err != nil { + // 删除失败不影响整体流程,跳过 + continue + } + } + return body +} + +// claudeVersionRe 匹配 Claude 模型 ID 中的版本号部分 +// 支持 claude-{tier}-{major}-{minor} 和 claude-{tier}-{major}.{minor} 格式 +var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`) + +// isBedrockClaude45OrNewer 判断 Bedrock 模型 ID 是否为 Claude 4.5 或更新版本 +// Claude 4.5+ 支持 cache_control 中的 ttl 字段("5m" 和 "1h") +func isBedrockClaude45OrNewer(modelID string) bool { + lower := strings.ToLower(modelID) + matches := claudeVersionRe.FindStringSubmatch(lower) + if matches == nil { + return false + } + major, _ := strconv.Atoi(matches[1]) + minor, _ := strconv.Atoi(matches[2]) + return major > 4 || (major == 4 && minor >= 5) +} + +// sanitizeBedrockCacheControl 清理 system 和 messages 中 cache_control 里 +// Bedrock 不支持的字段: +// - scope:Bedrock 不支持(如 "global" 跨请求缓存) +// - ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",旧模型需要移除 +func sanitizeBedrockCacheControl(body []byte, modelID string) []byte { + isClaude45 := isBedrockClaude45OrNewer(modelID) + + // 清理 system 数组中的 cache_control + systemArr := gjson.GetBytes(body, "system") + if systemArr.Exists() && systemArr.IsArray() { + for i, item := range systemArr.Array() { + if !item.IsObject() { + continue + } + cc := item.Get("cache_control") + if !cc.Exists() || !cc.IsObject() { + continue + } + body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45) + } + } + + // 清理 messages 中的 cache_control + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return body + } + for 
mi, msg := range messages.Array() { + if !msg.IsObject() { + continue + } + content := msg.Get("content") + if !content.Exists() || !content.IsArray() { + continue + } + for ci, block := range content.Array() { + if !block.IsObject() { + continue + } + cc := block.Get("cache_control") + if !cc.Exists() || !cc.IsObject() { + continue + } + body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45) + } + } + + return body +} + +// deleteCacheControlUnsupportedFields 删除给定 cache_control 路径下 Bedrock 不支持的字段 +func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte { + // Bedrock 不支持 scope(如 "global") + if cc.Get("scope").Exists() { + body, _ = sjson.DeleteBytes(body, basePath+".scope") + } + + // ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",其余情况移除 + ttl := cc.Get("ttl") + if ttl.Exists() { + shouldRemove := true + if isClaude45 { + v := ttl.String() + if v == "5m" || v == "1h" { + shouldRemove = false + } + } + if shouldRemove { + body, _ = sjson.DeleteBytes(body, basePath+".ttl") + } + } + + return body +} + +// parseAnthropicBetaHeader 解析 anthropic-beta 头的逗号分隔字符串为 token 列表 +func parseAnthropicBetaHeader(header string) []string { + header = strings.TrimSpace(header) + if header == "" { + return nil + } + if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") { + var parsed []any + if err := json.Unmarshal([]byte(header), &parsed); err == nil { + tokens := make([]string, 0, len(parsed)) + for _, item := range parsed { + token := strings.TrimSpace(fmt.Sprint(item)) + if token != "" { + tokens = append(tokens, token) + } + } + return tokens + } + } + var tokens []string + for _, part := range strings.Split(header, ",") { + t := strings.TrimSpace(part) + if t != "" { + tokens = append(tokens, t) + } + } + return tokens +} + +// bedrockSupportedBetaTokens 是 Bedrock Invoke 支持的 beta 头白名单 +// 参考: litellm/litellm/llms/bedrock/common_utils.py 
(anthropic_beta_headers_config.json) +// 更新策略: 当 AWS Bedrock 新增支持的 beta token 时需同步更新此白名单 +var bedrockSupportedBetaTokens = map[string]bool{ + "computer-use-2025-01-24": true, + "computer-use-2025-11-24": true, + "context-1m-2025-08-07": true, + "context-management-2025-06-27": true, + "compact-2026-01-12": true, + "interleaved-thinking-2025-05-14": true, + "tool-search-tool-2025-10-19": true, + "tool-examples-2025-10-29": true, +} + +// bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则 +// Anthropic 直接 API 使用通用头,Bedrock Invoke 需要特定的替代头 +var bedrockBetaTokenTransforms = map[string]string{ + "advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19", +} + +// autoInjectBedrockBetaTokens 根据请求体内容自动补齐必要的 beta token +// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() 和 +// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock() +// +// 客户端(特别是非 Claude Code 客户端)可能只在 body 中启用了功能而不在 header 中带对应 beta token, +// 这里通过检测请求体特征自动补齐,确保 Bedrock Invoke 不会因缺少必要 beta 头而 400。 +func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string { + seen := make(map[string]bool, len(tokens)) + for _, t := range tokens { + seen[t] = true + } + + inject := func(token string) { + if !seen[token] { + tokens = append(tokens, token) + seen[token] = true + } + } + + // 检测 thinking / interleaved thinking + // 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta + if gjson.GetBytes(body, "thinking").Exists() { + inject("interleaved-thinking-2025-05-14") + } + + // 检测 computer_use 工具 + // tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta + tools := gjson.GetBytes(body, "tools") + if tools.Exists() && tools.IsArray() { + toolSearchUsed := false + programmaticToolCallingUsed := false + inputExamplesUsed := false + for _, tool := range tools.Array() { + toolType := tool.Get("type").String() + if strings.HasPrefix(toolType, "computer_20") { + inject("computer-use-2025-11-24") + } + if isBedrockToolSearchType(toolType) { + 
toolSearchUsed = true + } + if hasCodeExecutionAllowedCallers(tool) { + programmaticToolCallingUsed = true + } + if hasInputExamples(tool) { + inputExamplesUsed = true + } + } + if programmaticToolCallingUsed || inputExamplesUsed { + // programmatic tool calling 和 input examples 需要 advanced-tool-use, + // 后续 filterBedrockBetaTokens 会将其转换为 Bedrock 特定的 tool-search-tool + inject("advanced-tool-use-2025-11-20") + } + if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) { + // 纯 tool search(无 programmatic/inputExamples)时直接注入 Bedrock 特定头, + // 跳过 advanced-tool-use → tool-search-tool 的转换步骤(与 litellm 对齐) + if !programmaticToolCallingUsed && !inputExamplesUsed { + inject("tool-search-tool-2025-10-19") + } else { + inject("advanced-tool-use-2025-11-20") + } + } + } + + return tokens +} + +func isBedrockToolSearchType(toolType string) bool { + return toolType == "tool_search_tool_regex_20251119" || toolType == "tool_search_tool_bm25_20251119" +} + +func hasCodeExecutionAllowedCallers(tool gjson.Result) bool { + allowedCallers := tool.Get("allowed_callers") + if containsStringInJSONArray(allowedCallers, "code_execution_20250825") { + return true + } + return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825") +} + +func hasInputExamples(tool gjson.Result) bool { + if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 { + return true + } + arr := tool.Get("function.input_examples") + return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 +} + +func containsStringInJSONArray(result gjson.Result, target string) bool { + if !result.Exists() || !result.IsArray() { + return false + } + for _, item := range result.Array() { + if item.String() == target { + return true + } + } + return false +} + +// bedrockModelSupportsToolSearch 判断 Bedrock 模型是否支持 tool search +// 目前仅 Claude Opus/Sonnet 4.5+ 支持,Haiku 不支持 +func bedrockModelSupportsToolSearch(modelID string) bool { + lower := 
strings.ToLower(modelID) + matches := claudeVersionRe.FindStringSubmatch(lower) + if matches == nil { + return false + } + // Haiku 不支持 tool search + if strings.Contains(lower, "haiku") { + return false + } + major, _ := strconv.Atoi(matches[1]) + minor, _ := strconv.Atoi(matches[2]) + return major > 4 || (major == 4 && minor >= 5) +} + +// filterBedrockBetaTokens 过滤并转换 beta token 列表,仅保留 Bedrock Invoke 支持的 token +// 1. 应用转换规则(如 advanced-tool-use → tool-search-tool) +// 2. 过滤掉 Bedrock 不支持的 token(如 output-128k, files-api, structured-outputs 等) +// 3. 自动关联 tool-examples(当 tool-search-tool 存在时) +func filterBedrockBetaTokens(tokens []string) []string { + seen := make(map[string]bool, len(tokens)) + var result []string + + for _, t := range tokens { + // 应用转换规则 + if replacement, ok := bedrockBetaTokenTransforms[t]; ok { + t = replacement + } + // 只保留白名单中的 token,且去重 + if bedrockSupportedBetaTokens[t] && !seen[t] { + result = append(result, t) + seen[t] = true + } + } + + // 自动关联: tool-search-tool 存在时,确保 tool-examples 也存在 + if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] { + result = append(result, "tool-examples-2025-10-29") + } + + return result +} diff --git a/backend/internal/service/bedrock_request_test.go b/backend/internal/service/bedrock_request_test.go new file mode 100644 index 00000000..361cafb4 --- /dev/null +++ b/backend/internal/service/bedrock_request_test.go @@ -0,0 +1,659 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) { + input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "") + require.NoError(t, err) + + // anthropic_version 应被注入 + assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, 
"anthropic_version").String()) + // model 和 stream 应被移除 + assert.False(t, gjson.GetBytes(result, "model").Exists()) + assert.False(t, gjson.GetBytes(result, "stream").Exists()) + // max_tokens 应保留 + assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int()) +} + +func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) { + t.Run("schema inlined into last user message array content", func(t *testing.T) { + input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "") + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "output_format").Exists()) + // schema 应内联到最后一条 user message 的 content 数组末尾 + contentArr := gjson.GetBytes(result, "messages.0.content").Array() + require.Len(t, contentArr, 2) + assert.Equal(t, "text", contentArr[1].Get("type").String()) + assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`) + }) + + t.Run("schema inlined into string content", func(t *testing.T) { + input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "") + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "output_format").Exists()) + contentArr := gjson.GetBytes(result, "messages.0.content").Array() + require.Len(t, contentArr, 2) + assert.Equal(t, "compute this", contentArr[0].Get("text").String()) + assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`) + }) + + t.Run("no schema field just removes output_format", func(t *testing.T) { + input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}` + result, err := 
PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "") + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "output_format").Exists()) + }) + + t.Run("no messages just removes output_format", func(t *testing.T) { + input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "") + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "output_format").Exists()) + }) +} + +func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) { + input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "") + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "output_config").Exists()) +} + +func TestRemoveCustomFieldFromTools(t *testing.T) { + input := `{ + "tools": [ + {"name":"tool1","custom":{"defer_loading":true},"description":"desc1"}, + {"name":"tool2","description":"desc2"}, + {"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"} + ] + }` + result := removeCustomFieldFromTools([]byte(input)) + + tools := gjson.GetBytes(result, "tools").Array() + require.Len(t, tools, 3) + // custom 应被移除 + assert.False(t, tools[0].Get("custom").Exists()) + // name/description 应保留 + assert.Equal(t, "tool1", tools[0].Get("name").String()) + assert.Equal(t, "desc1", tools[0].Get("description").String()) + // 没有 custom 的工具不受影响 + assert.Equal(t, "tool2", tools[1].Get("name").String()) + // 第三个工具的 custom 也应被移除 + assert.False(t, tools[2].Get("custom").Exists()) + assert.Equal(t, "tool3", tools[2].Get("name").String()) +} + +func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) { + input := `{"messages":[{"role":"user","content":"hi"}]}` + result := removeCustomFieldFromTools([]byte(input)) + // 无 tools 时不改变原始数据 + 
assert.JSONEq(t, input, string(result)) +} + +func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) { + input := `{ + "system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}], + "messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}] + }` + result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1") + + // scope 应被移除 + assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists()) + assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists()) + // type 应保留 + assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String()) + assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String()) +} + +func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) { + input := `{ + "system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}] + }` + // 旧模型(Claude 3.5)不支持 ttl + result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0") + + assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists()) + assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String()) +} + +func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) { + input := `{ + "system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}] + }` + // Claude 4.5+ 支持 "5m" 和 "1h" + result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0") + + assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists()) + assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String()) +} + +func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) { + input := `{ + "system": 
[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}] + }` + // Claude 4.5 不支持 "10m" + result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0") + + assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists()) +} + +func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) { + input := `{ + "messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}] + }` + result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1") + + assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists()) + assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String()) +} + +func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) { + input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}` + result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1") + // 无 cache_control 时不改变原始数据 + assert.JSONEq(t, input, string(result)) +} + +func TestIsBedrockClaude45OrNewer(t *testing.T) { + tests := []struct { + modelID string + expect bool + }{ + {"us.anthropic.claude-opus-4-6-v1", true}, + {"us.anthropic.claude-sonnet-4-6", true}, + {"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true}, + {"us.anthropic.claude-opus-4-5-20251101-v1:0", true}, + {"us.anthropic.claude-haiku-4-5-20251001-v1:0", true}, + {"anthropic.claude-3-5-sonnet-20241022-v2:0", false}, + {"anthropic.claude-3-opus-20240229-v1:0", false}, + {"anthropic.claude-3-haiku-20240307-v1:0", false}, + // 未来版本应自动支持 + {"us.anthropic.claude-sonnet-5-0-v1", true}, + {"us.anthropic.claude-opus-4-7-v1", true}, + // 旧版本 + {"anthropic.claude-opus-4-1-v1", false}, + {"anthropic.claude-sonnet-4-0-v1", false}, + // 非 Claude 模型 + {"amazon.nova-pro-v1", false}, + {"meta.llama3-70b", false}, + } + for _, tt := range 
tests { + t.Run(tt.modelID, func(t *testing.T) { + assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID)) + }) + } +} + +func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) { + // 模拟一个完整的 Claude Code 请求 + input := `{ + "model": "claude-opus-4-6", + "stream": true, + "max_tokens": 16384, + "output_format": {"type": "json", "schema": {"result": "string"}}, + "output_config": {"max_tokens": 100}, + "system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}], + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]} + ], + "tools": [ + {"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}}, + {"name": "read", "description": "Read file", "input_schema": {"type": "object"}} + ] + }` + + betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12" + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader) + require.NoError(t, err) + + // 基本字段 + assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String()) + assert.False(t, gjson.GetBytes(result, "model").Exists()) + assert.False(t, gjson.GetBytes(result, "stream").Exists()) + assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int()) + + // anthropic_beta 应包含所有 beta tokens + betaArr := gjson.GetBytes(result, "anthropic_beta").Array() + require.Len(t, betaArr, 3) + assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String()) + assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String()) + assert.Equal(t, "compact-2026-01-12", betaArr[2].String()) + + // output_format 应被移除,schema 内联到最后一条 user message + assert.False(t, gjson.GetBytes(result, "output_format").Exists()) + assert.False(t, gjson.GetBytes(result, "output_config").Exists()) + // content 数组:原始 text block + 内联 
schema block + contentArr := gjson.GetBytes(result, "messages.0.content").Array() + require.Len(t, contentArr, 2) + assert.Equal(t, "hello", contentArr[0].Get("text").String()) + assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`) + + // tools 中的 custom 应被移除 + assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists()) + assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String()) + assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String()) + + // cache_control: scope 应被移除,ttl 在 Claude 4.6 上保留合法值 + assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists()) + assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String()) + assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String()) + assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String()) +} + +func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) { + input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}` + + t.Run("empty beta header", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "") + require.NoError(t, err) + assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists()) + }) + + t.Run("single beta token", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14") + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + require.Len(t, arr, 1) + assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) + }) + + t.Run("multiple beta tokens with spaces", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ") + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + 
require.Len(t, arr, 2) + assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) + assert.Equal(t, "context-1m-2025-08-07", arr[1].String()) + }) + + t.Run("json array beta header", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`) + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + require.Len(t, arr, 2) + assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) + assert.Equal(t, "context-1m-2025-08-07", arr[1].String()) + }) +} + +func TestParseAnthropicBetaHeader(t *testing.T) { + assert.Nil(t, parseAnthropicBetaHeader("")) + assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a")) + assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b")) + assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b ")) + assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c")) + assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`)) +} + +func TestFilterBedrockBetaTokens(t *testing.T) { + t.Run("supported tokens pass through", func(t *testing.T) { + tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"} + result := filterBedrockBetaTokens(tokens) + assert.Equal(t, tokens, result) + }) + + t.Run("unsupported tokens are filtered out", func(t *testing.T) { + tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"} + result := filterBedrockBetaTokens(tokens) + assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result) + }) + + t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) { + tokens := []string{"advanced-tool-use-2025-11-20"} + result := filterBedrockBetaTokens(tokens) + assert.Contains(t, result, "tool-search-tool-2025-10-19") + // tool-examples 自动关联 + assert.Contains(t, 
result, "tool-examples-2025-10-29") + }) + + t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) { + tokens := []string{"tool-search-tool-2025-10-19"} + result := filterBedrockBetaTokens(tokens) + assert.Contains(t, result, "tool-search-tool-2025-10-19") + assert.Contains(t, result, "tool-examples-2025-10-29") + }) + + t.Run("no duplication when tool-examples already present", func(t *testing.T) { + tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"} + result := filterBedrockBetaTokens(tokens) + count := 0 + for _, t := range result { + if t == "tool-examples-2025-10-29" { + count++ + } + } + assert.Equal(t, 1, count) + }) + + t.Run("empty input returns nil", func(t *testing.T) { + result := filterBedrockBetaTokens(nil) + assert.Nil(t, result) + }) + + t.Run("all unsupported returns nil", func(t *testing.T) { + result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"}) + assert.Nil(t, result) + }) + + t.Run("duplicate tokens are deduplicated", func(t *testing.T) { + tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"} + result := filterBedrockBetaTokens(tokens) + assert.Equal(t, []string{"context-1m-2025-08-07"}, result) + }) +} + +func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) { + input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}` + + t.Run("unsupported beta tokens are filtered", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", + "interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14") + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + require.Len(t, arr, 1) + assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) + }) + + t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", + 
"advanced-tool-use-2025-11-20") + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + require.Len(t, arr, 2) + assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String()) + assert.Equal(t, "tool-examples-2025-10-29", arr[1].String()) + }) +} + +func TestBedrockCrossRegionPrefix(t *testing.T) { + tests := []struct { + region string + expect string + }{ + // US regions + {"us-east-1", "us"}, + {"us-east-2", "us"}, + {"us-west-1", "us"}, + {"us-west-2", "us"}, + // GovCloud + {"us-gov-east-1", "us-gov"}, + {"us-gov-west-1", "us-gov"}, + // EU regions + {"eu-west-1", "eu"}, + {"eu-west-2", "eu"}, + {"eu-west-3", "eu"}, + {"eu-central-1", "eu"}, + {"eu-central-2", "eu"}, + {"eu-north-1", "eu"}, + {"eu-south-1", "eu"}, + // APAC regions + {"ap-northeast-1", "jp"}, + {"ap-northeast-2", "apac"}, + {"ap-southeast-1", "apac"}, + {"ap-southeast-2", "au"}, + {"ap-south-1", "apac"}, + // Canada / South America fallback to us + {"ca-central-1", "us"}, + {"sa-east-1", "us"}, + // Unknown defaults to us + {"me-south-1", "us"}, + } + for _, tt := range tests { + t.Run(tt.region, func(t *testing.T) { + assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region)) + }) + } +} + +func TestResolveBedrockModelID(t *testing.T) { + t.Run("default alias resolves and adjusts region", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "eu-west-1", + }, + } + + modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5") + require.True(t, ok) + assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID) + }) + + t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "ap-southeast-2", + "model_mapping": map[string]any{ + "claude-*": "claude-opus-4-6", + }, + }, + } + + modelID, ok := 
ResolveBedrockModelID(account, "claude-opus-4-6-thinking") + require.True(t, ok) + assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID) + }) + + t.Run("force global rewrites anthropic regional model id", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "us-east-1", + "aws_force_global": "true", + "model_mapping": map[string]any{ + "claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6", + }, + }, + } + + modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6") + require.True(t, ok) + assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID) + }) + + t.Run("direct bedrock model id passes through", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "us-east-1", + }, + } + + modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0") + require.True(t, ok) + assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID) + }) + + t.Run("unsupported alias returns false", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "us-east-1", + }, + } + + _, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022") + assert.False(t, ok) + }) +} + +func TestAutoInjectBedrockBetaTokens(t *testing.T) { + t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) { + body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "interleaved-thinking-2025-05-14") + }) + + t.Run("no duplicate when already present", func(t *testing.T) { + body := 
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1") + count := 0 + for _, t := range result { + if t == "interleaved-thinking-2025-05-14" { + count++ + } + } + assert.Equal(t, 1, count) + }) + + t.Run("inject computer-use when computer tool present", func(t *testing.T) { + body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "computer-use-2025-11-24") + }) + + t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) { + body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "advanced-tool-use-2025-11-20") + }) + + t.Run("inject advanced-tool-use for input examples", func(t *testing.T) { + body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "advanced-tool-use-2025-11-20") + }) + + t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) { + body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6") + // 纯 tool search 场景直接注入 Bedrock 特定头,不走 advanced-tool-use 转换 + assert.Contains(t, result, "tool-search-tool-2025-10-19") + assert.NotContains(t, result, 
"advanced-tool-use-2025-11-20") + }) + + t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) { + body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6") + // 混合场景使用 advanced-tool-use(后续由 filter 转换为 tool-search-tool) + assert.Contains(t, result, "advanced-tool-use-2025-11-20") + }) + + t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) { + body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0") + assert.NotContains(t, result, "advanced-tool-use-2025-11-20") + assert.NotContains(t, result, "tool-search-tool-2025-10-19") + }) + + t.Run("no injection for regular tools", func(t *testing.T) { + body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Empty(t, result) + }) + + t.Run("no injection when no features detected", func(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Empty(t, result) + }) + + t.Run("preserves existing tokens", func(t *testing.T) { + body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`) + existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"} + result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "context-1m-2025-08-07") + 
assert.Contains(t, result, "compact-2026-01-12") + assert.Contains(t, result, "interleaved-thinking-2025-05-14") + }) +} + +func TestResolveBedrockBetaTokens(t *testing.T) { + t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) { + body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`) + result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "tool-search-tool-2025-10-19") + assert.Contains(t, result, "tool-examples-2025-10-29") + }) + + t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`) + result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1") + assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result) + }) +} + +func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) { + t.Run("thinking in body auto-injects beta without header", func(t *testing.T) { + input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "") + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + found := false + for _, v := range arr { + if v.String() == "interleaved-thinking-2025-05-14" { + found = true + } + } + assert.True(t, found, "interleaved-thinking should be auto-injected") + }) + + t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) { + input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07") + require.NoError(t, err) + arr := 
gjson.GetBytes(result, "anthropic_beta").Array() + names := make([]string, len(arr)) + for i, v := range arr { + names[i] = v.String() + } + assert.Contains(t, names, "context-1m-2025-08-07") + assert.Contains(t, names, "interleaved-thinking-2025-05-14") + }) +} + +func TestAdjustBedrockModelRegionPrefix(t *testing.T) { + tests := []struct { + name string + modelID string + region string + expect string + }{ + // US region — no change needed + {"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"}, + // EU region — replace us → eu + {"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"}, + {"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"}, + // APAC region — jp and au have dedicated prefixes per AWS docs + {"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"}, + {"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"}, + {"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"}, + // eu → us (user manually set eu prefix, moved to us region) + {"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"}, + // global prefix — replace to match region + {"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"}, + // No known prefix — leave unchanged + {"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"}, + // GovCloud — uses independent us-gov prefix + {"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"}, + {"govcloud already 
correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"}, + // Force global (special region value) + {"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"}, + {"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region)) + }) + } +} diff --git a/backend/internal/service/bedrock_signer.go b/backend/internal/service/bedrock_signer.go new file mode 100644 index 00000000..e7000b4d --- /dev/null +++ b/backend/internal/service/bedrock_signer.go @@ -0,0 +1,67 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" +) + +// BedrockSigner 使用 AWS SigV4 对 Bedrock 请求签名 +type BedrockSigner struct { + credentials aws.Credentials + region string + signer *v4.Signer +} + +// NewBedrockSigner 创建 BedrockSigner +func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner { + return &BedrockSigner{ + credentials: aws.Credentials{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + SessionToken: sessionToken, + }, + region: region, + signer: v4.NewSigner(), + } +} + +// NewBedrockSignerFromAccount 从 Account 凭证创建 BedrockSigner +func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) { + accessKeyID := account.GetCredential("aws_access_key_id") + if accessKeyID == "" { + return nil, fmt.Errorf("aws_access_key_id not found in credentials") + } + secretAccessKey := account.GetCredential("aws_secret_access_key") + if secretAccessKey == "" { + return nil, fmt.Errorf("aws_secret_access_key not found in credentials") + } + region := account.GetCredential("aws_region") + if region == "" { + 
region = defaultBedrockRegion + } + sessionToken := account.GetCredential("aws_session_token") // 可选 + + return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil +} + +// SignRequest 对 HTTP 请求进行 SigV4 签名 +// 重要约束:调用此方法前,req 应只包含 AWS 相关的 header(如 Content-Type、Accept)。 +// 非 AWS header(如 anthropic-beta)会参与签名计算,如果 Bedrock 服务端不识别这些 header, +// 签名验证可能失败。litellm 通过 _filter_headers_for_aws_signature 实现头过滤, +// 当前实现中 buildUpstreamRequestBedrock 仅设置了 Content-Type 和 Accept,因此是安全的。 +func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error { + payloadHash := sha256Hash(body) + return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now()) +} + +func sha256Hash(data []byte) string { + h := sha256.Sum256(data) + return hex.EncodeToString(h[:]) +} diff --git a/backend/internal/service/bedrock_signer_test.go b/backend/internal/service/bedrock_signer_test.go new file mode 100644 index 00000000..641e9341 --- /dev/null +++ b/backend/internal/service/bedrock_signer_test.go @@ -0,0 +1,35 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_access_key_id": "test-akid", + "aws_secret_access_key": "test-secret", + }, + } + + signer, err := NewBedrockSignerFromAccount(account) + require.NoError(t, err) + require.NotNil(t, signer) + assert.Equal(t, defaultBedrockRegion, signer.region) +} + +func TestFilterBetaTokens(t *testing.T) { + tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"} + filterSet := map[string]struct{}{ + "tool-search-tool-2025-10-19": {}, + } + + assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet)) + assert.Equal(t, tokens, 
filterBetaTokens(tokens, nil)) + assert.Nil(t, filterBetaTokens(nil, filterSet)) +} diff --git a/backend/internal/service/bedrock_stream.go b/backend/internal/service/bedrock_stream.go new file mode 100644 index 00000000..98196d27 --- /dev/null +++ b/backend/internal/service/bedrock_stream.go @@ -0,0 +1,414 @@ +package service + +import ( + "bufio" + "context" + "encoding/base64" + "errors" + "fmt" + "hash/crc32" + "io" + "net/http" + "sync/atomic" + "time" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// handleBedrockStreamingResponse 处理 Bedrock InvokeModelWithResponseStream 的 EventStream 响应 +// Bedrock 返回 AWS EventStream 二进制格式,每个事件的 payload 中 chunk.bytes 是 base64 编码的 +// Claude SSE 事件 JSON。本方法解码后转换为标准 SSE 格式写入客户端。 +func (s *GatewayService) handleBedrockStreamingResponse( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + startTime time.Time, + model string, +) (*streamingResult, error) { + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-amzn-requestid"); v != "" { + c.Header("x-request-id", v) + } + + usage := &ClaudeUsage{} + var firstTokenMs *int + clientDisconnected := false + + // Bedrock EventStream 使用 application/vnd.amazon.eventstream 二进制格式。 + // 每个帧结构:total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4) + // 但更实用的方式是使用行扫描找 JSON chunks,因为 Bedrock 的响应在二进制帧中。 + // 我们使用 EventStream decoder 来正确解析。 + decoder := newBedrockEventStreamDecoder(resp.Body) + + type decodeEvent struct { + payload []byte + err error + } + events := make(chan decodeEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev decodeEvent) bool { + select 
{ + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt atomic.Int64 + lastReadAt.Store(time.Now().UnixNano()) + + go func() { + defer close(events) + for { + payload, err := decoder.Decode() + if err != nil { + if err == io.EOF { + return + } + _ = sendEvent(decodeEvent{err: err}) + return + } + lastReadAt.Store(time.Now().UnixNano()) + if !sendEvent(decodeEvent{payload: payload}) { + return + } + } + }() + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + for { + select { + case ev, ok := <-events: + if !ok { + if !clientDisconnected { + flusher.Flush() + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if ev.err != nil { + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err) + } + + // payload 是 JSON,提取 chunk.bytes(base64 编码的 Claude SSE 事件数据) + sseData := extractBedrockChunkData(ev.payload) + if sseData == nil { + continue + } + + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + // 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式 + // 同时移除该字段避免透传给客户端 + sseData = 
transformBedrockInvocationMetrics(sseData) + + // 解析 SSE 事件数据提取 usage + s.parseSSEUsagePassthrough(string(sseData), usage) + + // 确定 SSE event type + eventType := gjson.GetBytes(sseData, "type").String() + + // 写入标准 SSE 格式 + if !clientDisconnected { + var writeErr error + if eventType != "" { + _, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData) + } else { + _, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData) + } + if writeErr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID) + } else { + flusher.Flush() + } + } + + case <-intervalCh: + lastRead := time.Unix(0, lastReadAt.Load()) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, model) + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + } + } +} + +// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据 +// Bedrock payload 格式:{"bytes":""} +func extractBedrockChunkData(payload []byte) []byte { + b64 := gjson.GetBytes(payload, "bytes").String() + if b64 == "" { + return nil + } + decoded, err := base64.StdEncoding.DecodeString(b64) + if err != nil { + return nil + } + return decoded +} + +// transformBedrockInvocationMetrics 将 Bedrock 特有的 amazon-bedrock-invocationMetrics +// 转换为标准 Anthropic usage 格式,并从 SSE 数据中移除该字段。 +// +// Bedrock Invoke 返回的 message_delta 事件可能包含: +// +// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}} +// +// 转换为: +// 
+// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}} +func transformBedrockInvocationMetrics(data []byte) []byte { + metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics") + if !metrics.Exists() || !metrics.IsObject() { + return data + } + + // 移除 Bedrock 特有字段 + data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics") + + // 如果已有标准 usage 字段,不覆盖 + if gjson.GetBytes(data, "usage").Exists() { + return data + } + + // 转换 camelCase → snake_case 写入 usage + inputTokens := metrics.Get("inputTokenCount") + outputTokens := metrics.Get("outputTokenCount") + if inputTokens.Exists() { + data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int()) + } + if outputTokens.Exists() { + data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int()) + } + + return data +} + +// bedrockEventStreamDecoder 解码 AWS EventStream 二进制帧 +// EventStream 帧格式: +// +// [total_byte_length: 4 bytes] +// [headers_byte_length: 4 bytes] +// [prelude_crc: 4 bytes] +// [headers: variable] +// [payload: variable] +// [message_crc: 4 bytes] +type bedrockEventStreamDecoder struct { + reader *bufio.Reader +} + +func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder { + return &bedrockEventStreamDecoder{ + reader: bufio.NewReaderSize(r, 64*1024), + } +} + +// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload +func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) { + for { + // 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes + prelude := make([]byte, 12) + if _, err := io.ReadFull(d.reader, prelude); err != nil { + return nil, err + } + + // 验证 prelude CRC(AWS EventStream 使用标准 CRC32 / IEEE) + preludeCRC := bedrockReadUint32(prelude[8:12]) + if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC { + return nil, fmt.Errorf("eventstream prelude CRC mismatch") + } + + totalLength := bedrockReadUint32(prelude[0:4]) + headersLength := bedrockReadUint32(prelude[4:8]) + + if 
totalLength < 16 { // minimum: 12 prelude + 4 message_crc + return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength) + } + + // 读取 headers + payload + message_crc + remaining := int(totalLength) - 12 + if remaining <= 0 { + continue + } + data := make([]byte, remaining) + if _, err := io.ReadFull(d.reader, data); err != nil { + return nil, err + } + + // 验证 message CRC(覆盖 prelude + headers + payload) + messageCRC := bedrockReadUint32(data[len(data)-4:]) + h := crc32.New(crc32IEEETable) + _, _ = h.Write(prelude) + _, _ = h.Write(data[:len(data)-4]) + if h.Sum32() != messageCRC { + return nil, fmt.Errorf("eventstream message CRC mismatch") + } + + // 解析 headers + headers := data[:headersLength] + payload := data[headersLength : len(data)-4] // 去掉 message_crc + + // 从 headers 中提取 :event-type + eventType := extractEventStreamHeaderValue(headers, ":event-type") + + // 只处理 chunk 事件 + if eventType == "chunk" { + // payload 是完整的 JSON,包含 bytes 字段 + return payload, nil + } + + // 检查异常事件 + exceptionType := extractEventStreamHeaderValue(headers, ":exception-type") + if exceptionType != "" { + return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload)) + } + + messageType := extractEventStreamHeaderValue(headers, ":message-type") + if messageType == "exception" || messageType == "error" { + return nil, fmt.Errorf("bedrock error: %s", string(payload)) + } + + // 跳过其他事件类型(如 initial-response) + } +} + +// extractEventStreamHeaderValue 从 EventStream headers 二进制数据中提取指定 header 的字符串值 +// EventStream header 格式: +// +// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable] +// +// value_type = 7 表示 string 类型,前 2 bytes 为长度 +func extractEventStreamHeaderValue(headers []byte, targetName string) string { + pos := 0 + for pos < len(headers) { + if pos >= len(headers) { + break + } + nameLen := int(headers[pos]) + pos++ + if pos+nameLen > len(headers) { + break + } + name := string(headers[pos : pos+nameLen]) + pos += 
nameLen + + if pos >= len(headers) { + break + } + valueType := headers[pos] + pos++ + + switch valueType { + case 7: // string + if pos+2 > len(headers) { + return "" + } + valueLen := int(bedrockReadUint16(headers[pos : pos+2])) + pos += 2 + if pos+valueLen > len(headers) { + return "" + } + value := string(headers[pos : pos+valueLen]) + pos += valueLen + if name == targetName { + return value + } + case 0: // bool true + if name == targetName { + return "true" + } + case 1: // bool false + if name == targetName { + return "false" + } + case 2: // byte + pos++ + if name == targetName { + return "" + } + case 3: // short + pos += 2 + if name == targetName { + return "" + } + case 4: // int + pos += 4 + if name == targetName { + return "" + } + case 5: // long + pos += 8 + if name == targetName { + return "" + } + case 6: // bytes + if pos+2 > len(headers) { + return "" + } + valueLen := int(bedrockReadUint16(headers[pos : pos+2])) + pos += 2 + valueLen + case 8: // timestamp + pos += 8 + case 9: // uuid + pos += 16 + default: + return "" // 未知类型,无法继续解析 + } + } + return "" +} + +// crc32IEEETable is the CRC32 / IEEE table used by AWS EventStream. 
+var crc32IEEETable = crc32.MakeTable(crc32.IEEE) + +func bedrockReadUint32(b []byte) uint32 { + return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) +} + +func bedrockReadUint16(b []byte) uint16 { + return uint16(b[0])<<8 | uint16(b[1]) +} diff --git a/backend/internal/service/bedrock_stream_test.go b/backend/internal/service/bedrock_stream_test.go new file mode 100644 index 00000000..3d066137 --- /dev/null +++ b/backend/internal/service/bedrock_stream_test.go @@ -0,0 +1,261 @@ +package service + +import ( + "bytes" + "encoding/base64" + "encoding/binary" + "hash/crc32" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestExtractBedrockChunkData(t *testing.T) { + t.Run("valid base64 payload", func(t *testing.T) { + original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}` + b64 := base64.StdEncoding.EncodeToString([]byte(original)) + payload := []byte(`{"bytes":"` + b64 + `"}`) + + result := extractBedrockChunkData(payload) + require.NotNil(t, result) + assert.JSONEq(t, original, string(result)) + }) + + t.Run("empty bytes field", func(t *testing.T) { + result := extractBedrockChunkData([]byte(`{"bytes":""}`)) + assert.Nil(t, result) + }) + + t.Run("no bytes field", func(t *testing.T) { + result := extractBedrockChunkData([]byte(`{"other":"value"}`)) + assert.Nil(t, result) + }) + + t.Run("invalid base64", func(t *testing.T) { + result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`)) + assert.Nil(t, result) + }) +} + +func TestTransformBedrockInvocationMetrics(t *testing.T) { + t.Run("converts metrics to usage", func(t *testing.T) { + input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}` + result := transformBedrockInvocationMetrics([]byte(input)) + + // amazon-bedrock-invocationMetrics should 
be removed + assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists()) + // usage should be set + assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int()) + assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int()) + // original fields preserved + assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String()) + assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String()) + }) + + t.Run("no metrics present", func(t *testing.T) { + input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}` + result := transformBedrockInvocationMetrics([]byte(input)) + assert.JSONEq(t, input, string(result)) + }) + + t.Run("does not overwrite existing usage", func(t *testing.T) { + input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}` + result := transformBedrockInvocationMetrics([]byte(input)) + + // metrics removed but existing usage preserved + assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists()) + assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int()) + }) +} + +func TestExtractEventStreamHeaderValue(t *testing.T) { + // Build a header with :event-type = "chunk" (string type = 7) + buildStringHeader := func(name, value string) []byte { + var buf bytes.Buffer + // name length (1 byte) + _ = buf.WriteByte(byte(len(name))) + // name + _, _ = buf.WriteString(name) + // value type (7 = string) + _ = buf.WriteByte(7) + // value length (2 bytes, big-endian) + _ = binary.Write(&buf, binary.BigEndian, uint16(len(value))) + // value + _, _ = buf.WriteString(value) + return buf.Bytes() + } + + t.Run("find string header", func(t *testing.T) { + headers := buildStringHeader(":event-type", "chunk") + assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type")) + }) + + t.Run("header not 
found", func(t *testing.T) { + headers := buildStringHeader(":event-type", "chunk") + assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type")) + }) + + t.Run("multiple headers", func(t *testing.T) { + var buf bytes.Buffer + _, _ = buf.Write(buildStringHeader(":content-type", "application/json")) + _, _ = buf.Write(buildStringHeader(":event-type", "chunk")) + _, _ = buf.Write(buildStringHeader(":message-type", "event")) + + headers := buf.Bytes() + assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type")) + assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type")) + assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type")) + }) + + t.Run("empty headers", func(t *testing.T) { + assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type")) + }) +} + +func TestBedrockEventStreamDecoder(t *testing.T) { + crc32IeeeTab := crc32.MakeTable(crc32.IEEE) + + // Build a valid EventStream frame with correct CRC32/IEEE checksums. 
+ buildFrame := func(eventType string, payload []byte) []byte { + // Build headers + var headersBuf bytes.Buffer + // :event-type header + _ = headersBuf.WriteByte(byte(len(":event-type"))) + _, _ = headersBuf.WriteString(":event-type") + _ = headersBuf.WriteByte(7) // string type + _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType))) + _, _ = headersBuf.WriteString(eventType) + // :message-type header + _ = headersBuf.WriteByte(byte(len(":message-type"))) + _, _ = headersBuf.WriteString(":message-type") + _ = headersBuf.WriteByte(7) + _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event"))) + _, _ = headersBuf.WriteString("event") + + headers := headersBuf.Bytes() + headersLen := uint32(len(headers)) + // total = 12 (prelude) + headers + payload + 4 (message_crc) + totalLen := uint32(12 + len(headers) + len(payload) + 4) + + // Prelude: total_length(4) + headers_length(4) + var preludeBuf bytes.Buffer + _ = binary.Write(&preludeBuf, binary.BigEndian, totalLen) + _ = binary.Write(&preludeBuf, binary.BigEndian, headersLen) + preludeBytes := preludeBuf.Bytes() + preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab) + + // Build frame: prelude + prelude_crc + headers + payload + var frame bytes.Buffer + _, _ = frame.Write(preludeBytes) + _ = binary.Write(&frame, binary.BigEndian, preludeCRC) + _, _ = frame.Write(headers) + _, _ = frame.Write(payload) + + // Message CRC covers everything before itself + messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab) + _ = binary.Write(&frame, binary.BigEndian, messageCRC) + return frame.Bytes() + } + + t.Run("decode chunk event", func(t *testing.T) { + payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test") + frame := buildFrame("chunk", payload) + + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame)) + result, err := decoder.Decode() + require.NoError(t, err) + assert.Equal(t, payload, result) + }) + + t.Run("skip non-chunk events", func(t *testing.T) { + // Write 
initial-response followed by chunk + var buf bytes.Buffer + _, _ = buf.Write(buildFrame("initial-response", []byte(`{}`))) + chunkPayload := []byte(`{"bytes":"aGVsbG8="}`) + _, _ = buf.Write(buildFrame("chunk", chunkPayload)) + + decoder := newBedrockEventStreamDecoder(&buf) + result, err := decoder.Decode() + require.NoError(t, err) + assert.Equal(t, chunkPayload, result) + }) + + t.Run("EOF on empty input", func(t *testing.T) { + decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil)) + _, err := decoder.Decode() + assert.Equal(t, io.EOF, err) + }) + + t.Run("corrupted prelude CRC", func(t *testing.T) { + frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`)) + // Corrupt the prelude CRC (bytes 8-11) + frame[8] ^= 0xFF + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame)) + _, err := decoder.Decode() + require.Error(t, err) + assert.Contains(t, err.Error(), "prelude CRC mismatch") + }) + + t.Run("corrupted message CRC", func(t *testing.T) { + frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`)) + // Corrupt the message CRC (last 4 bytes) + frame[len(frame)-1] ^= 0xFF + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame)) + _, err := decoder.Decode() + require.Error(t, err) + assert.Contains(t, err.Error(), "message CRC mismatch") + }) + + t.Run("castagnoli encoded frame is rejected", func(t *testing.T) { + castagnoliTab := crc32.MakeTable(crc32.Castagnoli) + payload := []byte(`{"bytes":"dGVzdA=="}`) + + var headersBuf bytes.Buffer + _ = headersBuf.WriteByte(byte(len(":event-type"))) + _, _ = headersBuf.WriteString(":event-type") + _ = headersBuf.WriteByte(7) + _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk"))) + _, _ = headersBuf.WriteString("chunk") + + headers := headersBuf.Bytes() + headersLen := uint32(len(headers)) + totalLen := uint32(12 + len(headers) + len(payload) + 4) + + var preludeBuf bytes.Buffer + _ = binary.Write(&preludeBuf, binary.BigEndian, totalLen) + _ = binary.Write(&preludeBuf, 
binary.BigEndian, headersLen) + preludeBytes := preludeBuf.Bytes() + + var frame bytes.Buffer + _, _ = frame.Write(preludeBytes) + _ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab)) + _, _ = frame.Write(headers) + _, _ = frame.Write(payload) + _ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab)) + + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes())) + _, err := decoder.Decode() + require.Error(t, err) + assert.Contains(t, err.Error(), "prelude CRC mismatch") + }) +} + +func TestBuildBedrockURL(t *testing.T) { + t.Run("stream URL with colon in model ID", func(t *testing.T) { + url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true) + assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url) + }) + + t.Run("non-stream URL with colon in model ID", func(t *testing.T) { + url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false) + assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url) + }) + + t.Run("model ID without colon", func(t *testing.T) { + url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true) + assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url) + }) +} diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 1a76f5f6..f2ad0a3d 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -40,6 +40,7 @@ const ( cacheWriteSetSubscription cacheWriteUpdateSubscriptionUsage cacheWriteDeductBalance + cacheWriteUpdateRateLimitUsage ) // 异步缓存写入工作池配置 @@ -68,19 +69,26 @@ type cacheWriteTask struct { kind cacheWriteKind 
userID int64 groupID int64 + apiKeyID int64 balance float64 amount float64 subscriptionData *subscriptionCacheData } +// apiKeyRateLimitLoader defines the interface for loading rate limit data from DB. +type apiKeyRateLimitLoader interface { + GetRateLimitData(ctx context.Context, keyID int64) (*APIKeyRateLimitData, error) +} + // BillingCacheService 计费缓存服务 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 type BillingCacheService struct { - cache BillingCache - userRepo UserRepository - subRepo UserSubscriptionRepository - cfg *config.Config - circuitBreaker *billingCircuitBreaker + cache BillingCache + userRepo UserRepository + subRepo UserSubscriptionRepository + apiKeyRateLimitLoader apiKeyRateLimitLoader + cfg *config.Config + circuitBreaker *billingCircuitBreaker cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup @@ -96,12 +104,13 @@ type BillingCacheService struct { } // NewBillingCacheService 创建计费缓存服务 -func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService { +func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService { svc := &BillingCacheService{ - cache: cache, - userRepo: userRepo, - subRepo: subRepo, - cfg: cfg, + cache: cache, + userRepo: userRepo, + subRepo: subRepo, + apiKeyRateLimitLoader: apiKeyRepo, + cfg: cfg, } svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) svc.startCacheWriteWorkers() @@ -188,6 +197,12 @@ func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) { logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err) } } + case cacheWriteUpdateRateLimitUsage: + if s.cache != nil { + if err := s.cache.UpdateAPIKeyRateLimitUsage(ctx, task.apiKeyID, task.amount); err != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: update 
rate limit usage cache failed for api key %d: %v", task.apiKeyID, err) + } + } } cancel() } @@ -204,6 +219,8 @@ func cacheWriteKindName(kind cacheWriteKind) string { return "update_subscription_usage" case cacheWriteDeductBalance: return "deduct_balance" + case cacheWriteUpdateRateLimitUsage: + return "update_rate_limit_usage" default: return "unknown" } @@ -476,6 +493,141 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID return nil } +// ============================================ +// API Key 限速缓存方法 +// ============================================ + +// checkAPIKeyRateLimits checks rate limit windows for an API key. +// It loads usage from Redis cache (falling back to DB on cache miss), +// resets expired windows in-memory and triggers async DB reset, +// and returns an error if any window limit is exceeded. +func (s *BillingCacheService) checkAPIKeyRateLimits(ctx context.Context, apiKey *APIKey) error { + if s.cache == nil { + // No cache: fall back to reading from DB directly + if s.apiKeyRateLimitLoader == nil { + return nil + } + data, err := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID) + if err != nil { + return nil // Don't block requests on DB errors + } + return s.evaluateRateLimits(ctx, apiKey, data.Usage5h, data.Usage1d, data.Usage7d, + data.Window5hStart, data.Window1dStart, data.Window7dStart) + } + + cacheData, err := s.cache.GetAPIKeyRateLimit(ctx, apiKey.ID) + if err != nil { + // Cache miss: load from DB and populate cache + if s.apiKeyRateLimitLoader == nil { + return nil + } + dbData, dbErr := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID) + if dbErr != nil { + return nil // Don't block requests on DB errors + } + // Build cache entry from DB data + cacheEntry := &APIKeyRateLimitCacheData{ + Usage5h: dbData.Usage5h, + Usage1d: dbData.Usage1d, + Usage7d: dbData.Usage7d, + } + if dbData.Window5hStart != nil { + cacheEntry.Window5h = dbData.Window5hStart.Unix() + } + if dbData.Window1dStart != 
nil { + cacheEntry.Window1d = dbData.Window1dStart.Unix() + } + if dbData.Window7dStart != nil { + cacheEntry.Window7d = dbData.Window7dStart.Unix() + } + _ = s.cache.SetAPIKeyRateLimit(ctx, apiKey.ID, cacheEntry) + cacheData = cacheEntry + } + + var w5h, w1d, w7d *time.Time + if cacheData.Window5h > 0 { + t := time.Unix(cacheData.Window5h, 0) + w5h = &t + } + if cacheData.Window1d > 0 { + t := time.Unix(cacheData.Window1d, 0) + w1d = &t + } + if cacheData.Window7d > 0 { + t := time.Unix(cacheData.Window7d, 0) + w7d = &t + } + return s.evaluateRateLimits(ctx, apiKey, cacheData.Usage5h, cacheData.Usage1d, cacheData.Usage7d, w5h, w1d, w7d) +} + +// evaluateRateLimits checks usage against limits, triggering async resets for expired windows. +func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *APIKey, usage5h, usage1d, usage7d float64, w5h, w1d, w7d *time.Time) error { + needsReset := false + + // Reset expired windows in-memory for check purposes + if IsWindowExpired(w5h, RateLimitWindow5h) { + usage5h = 0 + needsReset = true + } + if IsWindowExpired(w1d, RateLimitWindow1d) { + usage1d = 0 + needsReset = true + } + if IsWindowExpired(w7d, RateLimitWindow7d) { + usage7d = 0 + needsReset = true + } + + // Trigger async DB reset if any window expired + if needsReset { + keyID := apiKey.ID + go func() { + resetCtx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) + defer cancel() + if s.apiKeyRateLimitLoader != nil { + // Use the repo directly - reset then reload cache + if loader, ok := s.apiKeyRateLimitLoader.(interface { + ResetRateLimitWindows(ctx context.Context, id int64) error + }); ok { + if err := loader.ResetRateLimitWindows(resetCtx, keyID); err != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: reset rate limit windows failed for api key %d: %v", keyID, err) + } + } + } + // Invalidate cache so next request loads fresh data + if s.cache != nil { + if err := 
s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID); err != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: invalidate rate limit cache failed for api key %d: %v", keyID, err) + } + } + }() + } + + // Check limits + if apiKey.RateLimit5h > 0 && usage5h >= apiKey.RateLimit5h { + return ErrAPIKeyRateLimit5hExceeded + } + if apiKey.RateLimit1d > 0 && usage1d >= apiKey.RateLimit1d { + return ErrAPIKeyRateLimit1dExceeded + } + if apiKey.RateLimit7d > 0 && usage7d >= apiKey.RateLimit7d { + return ErrAPIKeyRateLimit7dExceeded + } + return nil +} + +// QueueUpdateAPIKeyRateLimitUsage asynchronously updates rate limit usage in the cache. +func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, cost float64) { + if s.cache == nil { + return + } + s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteUpdateRateLimitUsage, + apiKeyID: apiKeyID, + amount: cost, + }) +} + // ============================================ // 统一检查方法 // ============================================ @@ -496,10 +648,23 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil if isSubscriptionMode { - return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription) + if err := s.checkSubscriptionEligibility(ctx, user.ID, group, subscription); err != nil { + return err + } + } else { + if err := s.checkBalanceEligibility(ctx, user.ID); err != nil { + return err + } } - return s.checkBalanceEligibility(ctx, user.ID) + // Check API Key rate limits (applies to both billing modes) + if apiKey != nil && apiKey.HasRateLimits() { + if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil { + return err + } + } + + return nil } // checkBalanceEligibility 检查余额模式资格 diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go index 1b12c402..4a8b8f03 100644 --- 
a/backend/internal/service/billing_cache_service_singleflight_test.go +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -51,6 +51,22 @@ func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, return nil } +func (s *billingCacheMissStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) { + return nil, errors.New("cache miss") +} + +func (s *billingCacheMissStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error { + return nil +} + +func (s *billingCacheMissStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { + return nil +} + +func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + return nil +} + type balanceLoadUserRepoStub struct { mockUserRepo calls atomic.Int64 @@ -76,7 +92,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { delay: 80 * time.Millisecond, balance: 12.34, } - svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{}) + svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{}) t.Cleanup(svc.Stop) const goroutines = 16 diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go index 4e5f50e2..7d7045e2 100644 --- a/backend/internal/service/billing_cache_service_test.go +++ b/backend/internal/service/billing_cache_service_test.go @@ -52,9 +52,25 @@ func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context return nil } +func (b *billingCacheWorkerStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) { + return nil, errors.New("not implemented") +} + +func (b *billingCacheWorkerStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error { + return nil +} + +func (b *billingCacheWorkerStub) UpdateAPIKeyRateLimitUsage(ctx 
context.Context, keyID int64, cost float64) error { + return nil +} + +func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + return nil +} + func TestBillingCacheServiceQueueHighLoad(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) t.Cleanup(svc.Stop) start := time.Now() @@ -76,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) svc.Stop() enqueued := svc.enqueueCacheWrite(cacheWriteTask{ diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 6abd1e53..68d7a8f9 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -10,6 +10,16 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" ) +// APIKeyRateLimitCacheData holds rate limit usage data cached in Redis. 
+type APIKeyRateLimitCacheData struct { + Usage5h float64 `json:"usage_5h"` + Usage1d float64 `json:"usage_1d"` + Usage7d float64 `json:"usage_7d"` + Window5h int64 `json:"window_5h"` // unix timestamp, 0 = not started + Window1d int64 `json:"window_1d"` + Window7d int64 `json:"window_7d"` +} + // BillingCache defines cache operations for billing service type BillingCache interface { // Balance operations @@ -23,17 +33,57 @@ type BillingCache interface { SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error + + // API Key rate limit operations + GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) + SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error + UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error + InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error } // ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致) type ModelPricing struct { - InputPricePerToken float64 // 每token输入价格 (USD) - OutputPricePerToken float64 // 每token输出价格 (USD) - CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) - CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) - CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD) - CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) - SupportsCacheBreakdown bool // 是否支持详细的缓存分类 + InputPricePerToken float64 // 每token输入价格 (USD) + InputPricePerTokenPriority float64 // priority service tier 下每token输入价格 (USD) + OutputPricePerToken float64 // 每token输出价格 (USD) + OutputPricePerTokenPriority float64 // priority service tier 下每token输出价格 (USD) + CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) + CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) + CacheReadPricePerTokenPriority float64 // priority service tier 下缓存读取每token价格 
(USD) + CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD) + CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) + SupportsCacheBreakdown bool // 是否支持详细的缓存分类 + LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格 + LongContextInputMultiplier float64 // 长上下文整次会话输入倍率 + LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率 +} + +const ( + openAIGPT54LongContextInputThreshold = 272000 + openAIGPT54LongContextInputMultiplier = 2.0 + openAIGPT54LongContextOutputMultiplier = 1.5 +) + +func normalizeBillingServiceTier(serviceTier string) string { + return strings.ToLower(strings.TrimSpace(serviceTier)) +} + +func usePriorityServiceTierPricing(serviceTier string, pricing *ModelPricing) bool { + if pricing == nil || normalizeBillingServiceTier(serviceTier) != "priority" { + return false + } + return pricing.InputPricePerTokenPriority > 0 || pricing.OutputPricePerTokenPriority > 0 || pricing.CacheReadPricePerTokenPriority > 0 +} + +func serviceTierCostMultiplier(serviceTier string) float64 { + switch normalizeBillingServiceTier(serviceTier) { + case "priority": + return 2.0 + case "flex": + return 0.5 + default: + return 1.0 + } } // UsageTokens 使用的token数量 @@ -145,6 +195,65 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok SupportsCacheBreakdown: false, } + + // OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费) + s.fallbackPrices["gpt-5.1"] = &ModelPricing{ + InputPricePerToken: 1.25e-6, // $1.25 per MTok + InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok + OutputPricePerToken: 10e-6, // $10 per MTok + OutputPricePerTokenPriority: 20e-6, // $20 per MTok + CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok + CacheReadPricePerToken: 0.125e-6, + CacheReadPricePerTokenPriority: 0.25e-6, + SupportsCacheBreakdown: false, + } + // OpenAI GPT-5.4(业务指定价格) + s.fallbackPrices["gpt-5.4"] = &ModelPricing{ + InputPricePerToken: 2.5e-6, // $2.5 per MTok + InputPricePerTokenPriority: 5e-6, // $5 per MTok + OutputPricePerToken: 15e-6, // $15 per 
MTok + OutputPricePerTokenPriority: 30e-6, // $30 per MTok + CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok + CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok + CacheReadPricePerTokenPriority: 0.5e-6, // $0.5 per MTok + SupportsCacheBreakdown: false, + LongContextInputThreshold: openAIGPT54LongContextInputThreshold, + LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier, + LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, + } + // OpenAI GPT-5.2(本地兜底) + s.fallbackPrices["gpt-5.2"] = &ModelPricing{ + InputPricePerToken: 1.75e-6, + InputPricePerTokenPriority: 3.5e-6, + OutputPricePerToken: 14e-6, + OutputPricePerTokenPriority: 28e-6, + CacheCreationPricePerToken: 1.75e-6, + CacheReadPricePerToken: 0.175e-6, + CacheReadPricePerTokenPriority: 0.35e-6, + SupportsCacheBreakdown: false, + } + // Codex 族兜底统一按 GPT-5.1 Codex 价格计费 + s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{ + InputPricePerToken: 1.5e-6, // $1.5 per MTok + InputPricePerTokenPriority: 3e-6, // $3 per MTok + OutputPricePerToken: 12e-6, // $12 per MTok + OutputPricePerTokenPriority: 24e-6, // $24 per MTok + CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok + CacheReadPricePerToken: 0.15e-6, + CacheReadPricePerTokenPriority: 0.3e-6, + SupportsCacheBreakdown: false, + } + s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{ + InputPricePerToken: 1.75e-6, + InputPricePerTokenPriority: 3.5e-6, + OutputPricePerToken: 14e-6, + OutputPricePerTokenPriority: 28e-6, + CacheCreationPricePerToken: 1.75e-6, + CacheReadPricePerToken: 0.175e-6, + CacheReadPricePerTokenPriority: 0.35e-6, + SupportsCacheBreakdown: false, + } + s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"] } // getFallbackPricing 根据模型系列获取回退价格 @@ -173,12 +282,34 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { } return s.fallbackPrices["claude-3-haiku"] } + // Claude 未知型号统一回退到 Sonnet,避免计费中断。 + if strings.Contains(modelLower, "claude") { + return 
s.fallbackPrices["claude-sonnet-4"] + } if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") { return s.fallbackPrices["gemini-3.1-pro"] } - // 默认使用Sonnet价格 - return s.fallbackPrices["claude-sonnet-4"] + // OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。 + if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") { + normalized := normalizeCodexModel(modelLower) + switch normalized { + case "gpt-5.4": + return s.fallbackPrices["gpt-5.4"] + case "gpt-5.2": + return s.fallbackPrices["gpt-5.2"] + case "gpt-5.2-codex": + return s.fallbackPrices["gpt-5.2-codex"] + case "gpt-5.3-codex": + return s.fallbackPrices["gpt-5.3-codex"] + case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest": + return s.fallbackPrices["gpt-5.1-codex"] + case "gpt-5.1": + return s.fallbackPrices["gpt-5.1"] + } + } + + return nil } // GetModelPricing 获取模型价格配置 @@ -196,15 +327,21 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { price5m := litellmPricing.CacheCreationInputTokenCost price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr enableBreakdown := price1h > 0 && price1h > price5m - return &ModelPricing{ - InputPricePerToken: litellmPricing.InputCostPerToken, - OutputPricePerToken: litellmPricing.OutputCostPerToken, - CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, - CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, - CacheCreation5mPrice: price5m, - CacheCreation1hPrice: price1h, - SupportsCacheBreakdown: enableBreakdown, - }, nil + return s.applyModelSpecificPricingPolicy(model, &ModelPricing{ + InputPricePerToken: litellmPricing.InputCostPerToken, + InputPricePerTokenPriority: litellmPricing.InputCostPerTokenPriority, + OutputPricePerToken: litellmPricing.OutputCostPerToken, + OutputPricePerTokenPriority: litellmPricing.OutputCostPerTokenPriority, + CacheCreationPricePerToken: 
litellmPricing.CacheCreationInputTokenCost, + CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, + CacheReadPricePerTokenPriority: litellmPricing.CacheReadInputTokenCostPriority, + CacheCreation5mPrice: price5m, + CacheCreation1hPrice: price1h, + SupportsCacheBreakdown: enableBreakdown, + LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold, + LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier, + LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier, + }), nil } } @@ -212,7 +349,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { fallback := s.getFallbackPricing(model) if fallback != nil { log.Printf("[Billing] Using fallback pricing for model: %s", model) - return fallback, nil + return s.applyModelSpecificPricingPolicy(model, fallback), nil } return nil, fmt.Errorf("pricing not found for model: %s", model) @@ -220,18 +357,43 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { // CalculateCost 计算使用费用 func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) { + return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "") +} + +func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) { pricing, err := s.GetModelPricing(model) if err != nil { return nil, err } breakdown := &CostBreakdown{} + inputPricePerToken := pricing.InputPricePerToken + outputPricePerToken := pricing.OutputPricePerToken + cacheReadPricePerToken := pricing.CacheReadPricePerToken + tierMultiplier := 1.0 + if usePriorityServiceTierPricing(serviceTier, pricing) { + if pricing.InputPricePerTokenPriority > 0 { + inputPricePerToken = pricing.InputPricePerTokenPriority + } + if pricing.OutputPricePerTokenPriority > 0 { + outputPricePerToken = pricing.OutputPricePerTokenPriority + } + if 
pricing.CacheReadPricePerTokenPriority > 0 { + cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority + } + } else { + tierMultiplier = serviceTierCostMultiplier(serviceTier) + } + if s.shouldApplySessionLongContextPricing(tokens, pricing) { + inputPricePerToken *= pricing.LongContextInputMultiplier + outputPricePerToken *= pricing.LongContextOutputMultiplier + } // 计算输入token费用(使用per-token价格) - breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken + breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken // 计算输出token费用 - breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken + breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken // 计算缓存费用 if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { @@ -248,7 +410,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken } - breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken + breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken + + if tierMultiplier != 1.0 { + breakdown.InputCost *= tierMultiplier + breakdown.OutputCost *= tierMultiplier + breakdown.CacheCreationCost *= tierMultiplier + breakdown.CacheReadCost *= tierMultiplier + } // 计算总费用 breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + @@ -263,6 +432,45 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul return breakdown, nil } +func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing { + if pricing == nil { + return nil + } + if !isOpenAIGPT54Model(model) { + return pricing + } + if pricing.LongContextInputThreshold > 0 && pricing.LongContextInputMultiplier > 0 && pricing.LongContextOutputMultiplier > 0 { + return 
pricing + } + cloned := *pricing + if cloned.LongContextInputThreshold <= 0 { + cloned.LongContextInputThreshold = openAIGPT54LongContextInputThreshold + } + if cloned.LongContextInputMultiplier <= 0 { + cloned.LongContextInputMultiplier = openAIGPT54LongContextInputMultiplier + } + if cloned.LongContextOutputMultiplier <= 0 { + cloned.LongContextOutputMultiplier = openAIGPT54LongContextOutputMultiplier + } + return &cloned +} + +func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens, pricing *ModelPricing) bool { + if pricing == nil || pricing.LongContextInputThreshold <= 0 { + return false + } + if pricing.LongContextInputMultiplier <= 1 && pricing.LongContextOutputMultiplier <= 1 { + return false + } + totalInputTokens := tokens.InputTokens + tokens.CacheReadTokens + return totalInputTokens > pricing.LongContextInputThreshold +} + +func isOpenAIGPT54Model(model string) bool { + normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model))) + return normalized == "gpt-5.4" +} + // CalculateCostWithConfig 使用配置中的默认倍率计算费用 func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) { multiplier := s.cfg.Default.RateMultiplier diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 5eb278f6..45bbdcee 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -133,7 +133,7 @@ func TestGetModelPricing_CaseInsensitive(t *testing.T) { require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken) } -func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) { +func TestGetModelPricing_UnknownClaudeModelFallsBackToSonnet(t *testing.T) { svc := newTestBillingService() // 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格 @@ -142,6 +142,93 @@ func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) { require.InDelta(t, 3e-6, 
pricing.InputPricePerToken, 1e-12) } +func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-unknown-model") + require.Error(t, err) + require.Nil(t, pricing) + require.Contains(t, err.Error(), "pricing not found") +} + +func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-5.1") + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12) +} + +func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-5.4") + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12) + require.Equal(t, 272000, pricing.LongContextInputThreshold) + require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12) + require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12) +} + +func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 300000, + OutputTokens: 4000, + } + + cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0) + require.NoError(t, err) + + expectedInput := float64(tokens.InputTokens) * 2.5e-6 * 2.0 + expectedOutput := float64(tokens.OutputTokens) * 15e-6 * 1.5 + require.InDelta(t, expectedInput, cost.InputCost, 1e-10) + require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10) +} + +func TestGetFallbackPricing_FamilyMatching(t *testing.T) { + svc := newTestBillingService() + + tests := 
[]struct { + name string + model string + expectedInput float64 + expectNilPricing bool + }{ + {name: "empty model", model: " ", expectNilPricing: true}, + {name: "claude opus 4.6", model: "claude-opus-4.6-20260201", expectedInput: 5e-6}, + {name: "claude opus 4.5 alt separator", model: "claude-opus-4-5-20260101", expectedInput: 5e-6}, + {name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6}, + {name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6}, + {name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true}, + {name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6}, + {name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6}, + {name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6}, + {name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6}, + {name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6}, + {name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true}, + {name: "non supported family", model: "qwen-max", expectNilPricing: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pricing := svc.getFallbackPricing(tt.model) + if tt.expectNilPricing { + require.Nil(t, pricing) + return + } + require.NotNil(t, pricing) + require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12) + }) + } +} func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) { svc := newTestBillingService() @@ -435,3 +522,189 @@ func TestCalculateCost_LargeTokenCount(t *testing.T) { require.False(t, math.IsNaN(cost.TotalCost)) require.False(t, math.IsInf(cost.TotalCost, 0)) } + +func TestServiceTierCostMultiplier(t *testing.T) { + require.InDelta(t, 2.0, serviceTierCostMultiplier("priority"), 1e-12) + require.InDelta(t, 2.0, serviceTierCostMultiplier(" Priority "), 1e-12) + require.InDelta(t, 0.5, 
serviceTierCostMultiplier("flex"), 1e-12) + require.InDelta(t, 1.0, serviceTierCostMultiplier(""), 1e-12) + require.InDelta(t, 1.0, serviceTierCostMultiplier("default"), 1e-12) +} + +func TestCalculateCostWithServiceTier_OpenAIPriorityUsesPriorityPricing(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheReadTokens: 20} + + baseCost, err := svc.CalculateCost("gpt-5.1-codex", tokens, 1.0) + require.NoError(t, err) + + priorityCost, err := svc.CalculateCostWithServiceTier("gpt-5.1-codex", tokens, 1.0, "priority") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10) +} + +func TestCalculateCostWithServiceTier_FlexAppliesHalfMultiplier(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20} + + baseCost, err := svc.CalculateCost("gpt-5.4", tokens, 1.0) + require.NoError(t, err) + + flexCost, err := svc.CalculateCostWithServiceTier("gpt-5.4", tokens, 1.0, "flex") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*0.5, flexCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*0.5, flexCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*0.5, flexCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*0.5, flexCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10) +} + +func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWithoutExplicitPriorityPrice(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8} + + 
baseCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + priorityCost, err := svc.CalculateCostWithServiceTier("claude-sonnet-4", tokens, 1.0, "priority") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10) +} + +func TestBillingServiceGetModelPricing_UsesDynamicPriorityFields(t *testing.T) { + pricingSvc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.4": { + InputCostPerToken: 2.5e-6, + InputCostPerTokenPriority: 5e-6, + OutputCostPerToken: 15e-6, + OutputCostPerTokenPriority: 30e-6, + CacheCreationInputTokenCost: 2.5e-6, + CacheReadInputTokenCost: 0.25e-6, + CacheReadInputTokenCostPriority: 0.5e-6, + LongContextInputTokenThreshold: 272000, + LongContextInputCostMultiplier: 2.0, + LongContextOutputCostMultiplier: 1.5, + }, + }, + } + svc := NewBillingService(&config.Config{}, pricingSvc) + + pricing, err := svc.GetModelPricing("gpt-5.4") + require.NoError(t, err) + require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 30e-6, pricing.OutputPricePerTokenPriority, 1e-12) + require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 0.5e-6, pricing.CacheReadPricePerTokenPriority, 1e-12) + require.Equal(t, 272000, pricing.LongContextInputThreshold) + require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12) + require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12) +} + +func 
TestBillingServiceGetModelPricing_OpenAIFallbackGpt52Variants(t *testing.T) { + svc := newTestBillingService() + + gpt52, err := svc.GetModelPricing("gpt-5.2") + require.NoError(t, err) + require.NotNil(t, gpt52) + require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12) + require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12) + + gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex") + require.NoError(t, err) + require.NotNil(t, gpt52Codex) + require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12) + require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12) +} + +func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWhenExplicitPriceMissing(t *testing.T) { + svc := NewBillingService(&config.Config{}, &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "custom-no-priority": { + InputCostPerToken: 1e-6, + OutputCostPerToken: 2e-6, + CacheCreationInputTokenCost: 0.5e-6, + CacheReadInputTokenCost: 0.25e-6, + }, + }, + }) + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20} + + baseCost, err := svc.CalculateCost("custom-no-priority", tokens, 1.0) + require.NoError(t, err) + + priorityCost, err := svc.CalculateCostWithServiceTier("custom-no-priority", tokens, 1.0, "priority") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10) +} + +func TestGetModelPricing_OpenAIGpt52FallbacksExposePriorityPrices(t *testing.T) { + svc := newTestBillingService() + + gpt52, err := svc.GetModelPricing("gpt-5.2") + 
require.NoError(t, err) + require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12) + require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 14e-6, gpt52.OutputPricePerToken, 1e-12) + require.InDelta(t, 28e-6, gpt52.OutputPricePerTokenPriority, 1e-12) + + gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex") + require.NoError(t, err) + require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12) + require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 14e-6, gpt52Codex.OutputPricePerToken, 1e-12) + require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12) +} + +func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.T) { + svc := NewBillingService(&config.Config{}, &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "dynamic-tier-model": { + InputCostPerToken: 1e-6, + InputCostPerTokenPriority: 2e-6, + OutputCostPerToken: 3e-6, + OutputCostPerTokenPriority: 6e-6, + CacheCreationInputTokenCost: 4e-6, + CacheCreationInputTokenCostAbove1hr: 5e-6, + CacheReadInputTokenCost: 7e-7, + CacheReadInputTokenCostPriority: 8e-7, + LongContextInputTokenThreshold: 999, + LongContextInputCostMultiplier: 1.5, + LongContextOutputCostMultiplier: 1.25, + }, + }, + }) + + pricing, err := svc.GetModelPricing("dynamic-tier-model") + require.NoError(t, err) + require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 2e-6, pricing.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 3e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 6e-6, pricing.OutputPricePerTokenPriority, 1e-12) + require.InDelta(t, 4e-6, pricing.CacheCreation5mPrice, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12) + require.True(t, pricing.SupportsCacheBreakdown) + require.InDelta(t, 7e-7, pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 8e-7, pricing.CacheReadPricePerTokenPriority, 1e-12) + require.Equal(t, 999, 
pricing.LongContextInputThreshold) + require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12) + require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12) +} diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go index f6cab204..82fa31c4 100644 --- a/backend/internal/service/claude_token_provider.go +++ b/backend/internal/service/claude_token_provider.go @@ -4,7 +4,6 @@ import ( "context" "errors" "log/slog" - "strconv" "strings" "time" ) @@ -15,14 +14,17 @@ const ( claudeLockWaitTime = 200 * time.Millisecond ) -// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) +// ClaudeTokenCache token cache interface. type ClaudeTokenCache = GeminiTokenCache -// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token +// ClaudeTokenProvider manages access_token for Claude OAuth accounts. type ClaudeTokenProvider struct { - accountRepo AccountRepository - tokenCache ClaudeTokenCache - oauthService *OAuthService + accountRepo AccountRepository + tokenCache ClaudeTokenCache + oauthService *OAuthService + refreshAPI *OAuthRefreshAPI + executor OAuthRefreshExecutor + refreshPolicy ProviderRefreshPolicy } func NewClaudeTokenProvider( @@ -31,13 +33,25 @@ func NewClaudeTokenProvider( oauthService *OAuthService, ) *ClaudeTokenProvider { return &ClaudeTokenProvider{ - accountRepo: accountRepo, - tokenCache: tokenCache, - oauthService: oauthService, + accountRepo: accountRepo, + tokenCache: tokenCache, + oauthService: oauthService, + refreshPolicy: ClaudeProviderRefreshPolicy(), } } -// GetAccessToken 获取有效的 access_token +// SetRefreshAPI injects unified OAuth refresh API and executor. +func (p *ClaudeTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) { + p.refreshAPI = api + p.executor = executor +} + +// SetRefreshPolicy injects caller-side refresh policy. 
+func (p *ClaudeTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) { + p.refreshPolicy = policy +} + +// GetAccessToken returns a valid access_token. func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { if account == nil { return "", errors.New("account is nil") @@ -48,7 +62,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou cacheKey := ClaudeTokenCacheKey(account) - // 1. 先尝试缓存 + // 1) Try cache first. if p.tokenCache != nil { if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { slog.Debug("claude_token_cache_hit", "account_id", account.ID) @@ -60,114 +74,39 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou slog.Debug("claude_token_cache_miss", "account_id", account.ID) - // 2. 如果即将过期则刷新 + // 2) Refresh if needed (pre-expiry skew). expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew refreshFailed := false - if needsRefresh && p.tokenCache != nil { - locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) - if lockErr == nil && locked { - defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() - // 拿到锁后再次检查缓存(另一个 worker 可能已刷新) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil + if needsRefresh && p.refreshAPI != nil && p.executor != nil { + result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, claudeTokenRefreshSkew) + if err != nil { + if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { + return "", err } - - // 从数据库获取最新账户信息 - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh - } - expiresAt = account.GetCredentialAsTime("expires_at") - if expiresAt == nil || time.Until(*expiresAt) <= 
claudeTokenRefreshSkew { - if p.oauthService == nil { - slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID) - refreshFailed = true // 无法刷新,标记失败 - } else { - tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account) - if err != nil { - // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token - slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err) - refreshFailed = true // 刷新失败,标记以使用短 TTL - } else { - // 构建新 credentials,保留原有字段 - newCredentials := make(map[string]any) - for k, v := range account.Credentials { - newCredentials[k] = v - } - newCredentials["access_token"] = tokenInfo.AccessToken - newCredentials["token_type"] = tokenInfo.TokenType - newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10) - newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) - if tokenInfo.RefreshToken != "" { - newCredentials["refresh_token"] = tokenInfo.RefreshToken - } - if tokenInfo.Scope != "" { - newCredentials["scope"] = tokenInfo.Scope - } - account.Credentials = newCredentials - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr) - } - expiresAt = account.GetCredentialAsTime("expires_at") - } - } - } - } else if lockErr != nil { - // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时) - slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) - - // 检查 ctx 是否已取消 - if ctx.Err() != nil { - return "", ctx.Err() - } - - // 从数据库获取最新账户信息 - if p.accountRepo != nil { - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh - } - } - expiresAt = account.GetCredentialAsTime("expires_at") - - // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 - if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew { - if p.oauthService == nil { - slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID) - refreshFailed = 
true - } else { - tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account) - if err != nil { - slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err) - refreshFailed = true - } else { - // 构建新 credentials,保留原有字段 - newCredentials := make(map[string]any) - for k, v := range account.Credentials { - newCredentials[k] = v - } - newCredentials["access_token"] = tokenInfo.AccessToken - newCredentials["token_type"] = tokenInfo.TokenType - newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10) - newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) - if tokenInfo.RefreshToken != "" { - newCredentials["refresh_token"] = tokenInfo.RefreshToken - } - if tokenInfo.Scope != "" { - newCredentials["scope"] = tokenInfo.Scope - } - account.Credentials = newCredentials - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr) - } - expiresAt = account.GetCredentialAsTime("expires_at") - } + slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err) + refreshFailed = true + } else if result.LockHeld { + if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil { + time.Sleep(claudeLockWaitTime) + if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" { + slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID) + return token, nil } } } else { - // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存 + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") + } + } else if needsRefresh && p.tokenCache != nil { + // Backward-compatible test path when refreshAPI is not injected. 
+ locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if lockErr == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + } else if lockErr != nil { + slog.Warn("claude_token_lock_failed", "account_id", account.ID, "error", lockErr) + } else { time.Sleep(claudeLockWaitTime) if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID) @@ -181,22 +120,23 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou return "", errors.New("access_token not found in credentials") } - // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) + // 3) Populate cache with TTL. if p.tokenCache != nil { latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) if isStale && latestAccount != nil { - // 版本过时,使用 DB 中的最新 token slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID) accessToken = latestAccount.GetCredential("access_token") if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found after version check") } - // 不写入缓存,让下次请求重新处理 } else { ttl := 30 * time.Minute if refreshFailed { - // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动 - ttl = time.Minute + if p.refreshPolicy.FailureTTL > 0 { + ttl = p.refreshPolicy.FailureTTL + } else { + ttl = time.Minute + } slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed") } else if expiresAt != nil { until := time.Until(*expiresAt) diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 4dcf84e0..217b83d6 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -43,6 +43,9 @@ type ConcurrencyCache interface { // 清理过期槽位(后台任务) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error + + // 启动时清理旧进程遗留槽位与等待计数 + 
CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error } var ( @@ -59,13 +62,22 @@ func initRequestIDPrefix() string { return "r" + strconv.FormatUint(fallback, 36) } -// generateRequestID generates a unique request ID for concurrency slot tracking. -// Format: {process_random_prefix}-{base36_counter} +func RequestIDPrefix() string { + return requestIDPrefix +} + func generateRequestID() string { seq := requestIDCounter.Add(1) return requestIDPrefix + "-" + strconv.FormatUint(seq, 36) } +func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error { + if s == nil || s.cache == nil { + return nil + } + return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix()) +} + const ( // Default extra wait slots beyond concurrency limit defaultExtraWaitSlots = 20 diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go index 9ba43d93..078ba0dc 100644 --- a/backend/internal/service/concurrency_service_test.go +++ b/backend/internal/service/concurrency_service_test.go @@ -91,6 +91,32 @@ func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Conte return c.cleanupErr } +func (c *stubConcurrencyCacheForTest) CleanupStaleProcessSlots(_ context.Context, _ string) error { + return c.cleanupErr +} + +type trackingConcurrencyCache struct { + stubConcurrencyCacheForTest + cleanupPrefix string +} + +func (c *trackingConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, prefix string) error { + c.cleanupPrefix = prefix + return c.cleanupErr +} + +func TestCleanupStaleProcessSlots_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + require.NoError(t, svc.CleanupStaleProcessSlots(context.Background())) +} + +func TestCleanupStaleProcessSlots_DelegatesPrefix(t *testing.T) { + cache := &trackingConcurrencyCache{} + svc := NewConcurrencyService(cache) + require.NoError(t, svc.CleanupStaleProcessSlots(context.Background())) + 
require.Equal(t, RequestIDPrefix(), cache.cleanupPrefix) +} + func TestAcquireAccountSlot_Success(t *testing.T) { cache := &stubConcurrencyCacheForTest{acquireResult: true} svc := NewConcurrencyService(cache) diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index 040b2357..6a916740 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -221,7 +221,7 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts, }) if err != nil { - client = &http.Client{Timeout: 20 * time.Second} + return nil, fmt.Errorf("create http client failed: %w", err) } adminToken, err := crsLogin(ctx, client, normalizedURL, username, password) diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go index a67f8532..b58a1ea9 100644 --- a/backend/internal/service/dashboard_aggregation_service.go +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -35,6 +35,7 @@ type DashboardAggregationRepository interface { UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error CleanupUsageLogs(ctx context.Context, cutoff time.Time) error + CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error } @@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays) dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays) usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays) + dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays) aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff) if aggErr != nil { @@ 
-305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, if usageErr != nil { logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr) } - if aggErr == nil && usageErr == nil { + dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff) + if dedupErr != nil { + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr) + } + if aggErr == nil && usageErr == nil && dedupErr == nil { s.lastRetentionCleanup.Store(now) } } diff --git a/backend/internal/service/dashboard_aggregation_service_test.go b/backend/internal/service/dashboard_aggregation_service_test.go index a7058985..fbb671bb 100644 --- a/backend/internal/service/dashboard_aggregation_service_test.go +++ b/backend/internal/service/dashboard_aggregation_service_test.go @@ -12,12 +12,18 @@ import ( type dashboardAggregationRepoTestStub struct { aggregateCalls int + recomputeCalls int + cleanupUsageCalls int + cleanupDedupCalls int + ensurePartitionCalls int lastStart time.Time lastEnd time.Time watermark time.Time aggregateErr error cleanupAggregatesErr error cleanupUsageErr error + cleanupDedupErr error + ensurePartitionErr error } func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error { @@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s } func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + s.recomputeCalls++ return s.AggregateRange(ctx, start, end) } @@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context } func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error { + s.cleanupUsageCalls++ return s.cleanupUsageErr } +func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, 
cutoff time.Time) error { + s.cleanupDedupCalls++ + return s.cleanupDedupErr +} + func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { - return nil + s.ensurePartitionCalls++ + return s.ensurePartitionErr } func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) { @@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te svc.maybeCleanupRetention(context.Background(), time.Now().UTC()) require.Nil(t, svc.lastRetentionCleanup.Load()) + require.Equal(t, 1, repo.cleanupUsageCalls) + require.Equal(t, 1, repo.cleanupDedupCalls) +} + +func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.maybeCleanupRetention(context.Background(), time.Now().UTC()) + + require.Nil(t, svc.lastRetentionCleanup.Load()) + require.Equal(t, 1, repo.cleanupDedupCalls) +} + +func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Enabled: true, + IntervalSeconds: 60, + LookbackSeconds: 120, + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + UsageBillingDedupDays: 2, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.runScheduledAggregation() + + require.Equal(t, 1, repo.ensurePartitionCalls) + require.Equal(t, 1, repo.aggregateCalls) } func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) { diff --git 
a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index 2af43386..63cad243 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -327,6 +327,14 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end return trend, nil } +func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) { + ranking, err := s.usageRepo.GetUserSpendingRanking(ctx, startTime, endTime, limit) + if err != nil { + return nil, fmt.Errorf("get user spending ranking: %w", err) + } + return ranking, nil +} + func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime) if err != nil { diff --git a/backend/internal/service/dashboard_service_test.go b/backend/internal/service/dashboard_service_test.go index 59b83e66..2a7f47b6 100644 --- a/backend/internal/service/dashboard_service_test.go +++ b/backend/internal/service/dashboard_service_test.go @@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut return nil } +func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + return nil +} + func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { return nil } diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index b304bc9f..2d8681d4 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -33,6 +33,7 @@ const ( AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) AccountTypeAPIKey = 
domain.AccountTypeAPIKey // API Key类型账号 AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) ) // Redeem type constants @@ -74,11 +75,13 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" // Setting keys const ( // 注册设置 - SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 - SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 - SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 - SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) - SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 + SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 + SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 + SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单(JSON 数组) + SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 + SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) + SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接 + SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 // 邮件服务设置 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 @@ -113,8 +116,9 @@ const ( SettingKeyDocURL = "doc_url" // 文档链接 SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src) SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮 - SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口 - SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src) + SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 + SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src) + SettingKeyCustomMenuItems 
= "custom_menu_items" // 自定义菜单项(JSON 数组) // 默认配置 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 @@ -173,6 +177,20 @@ const ( // SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling. SettingKeyStreamTimeoutSettings = "stream_timeout_settings" + // ========================= + // Request Rectifier (请求整流器) + // ========================= + + // SettingKeyRectifierSettings stores JSON config for rectifier settings (thinking signature + budget). + SettingKeyRectifierSettings = "rectifier_settings" + + // ========================= + // Beta Policy Settings + // ========================= + + // SettingKeyBetaPolicySettings stores JSON config for beta policy rules. + SettingKeyBetaPolicySettings = "beta_policy_settings" + // ========================= // Sora S3 存储配置 // ========================= @@ -200,6 +218,12 @@ const ( // SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查) SettingKeyMinClaudeCodeVersion = "min_claude_code_version" + + // SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403) + SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling" + + // SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录 + SettingKeyBackendModeEnabled = "backend_mode_enabled" ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). 
diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go index 9d7d025e..297a954c 100644 --- a/backend/internal/service/error_policy_test.go +++ b/backend/internal/service/error_policy_test.go @@ -88,6 +88,51 @@ func TestCheckErrorPolicy(t *testing.T) { body: []byte(`overloaded service`), expected: ErrorPolicyTempUnscheduled, }, + { + name: "temp_unschedulable_401_first_hit_returns_temp_unscheduled", + account: &Account{ + ID: 14, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 401, + body: []byte(`unauthorized`), + expected: ErrorPolicyTempUnscheduled, + }, + { + // Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制), + // second hit 仍然返回 TempUnscheduled。 + name: "temp_unschedulable_401_second_hit_antigravity_stays_temp", + account: &Account{ + ID: 15, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 401, + body: []byte(`unauthorized`), + expected: ErrorPolicyTempUnscheduled, + }, { name: "temp_unschedulable_body_miss_returns_none", account: &Account{ @@ -134,6 +179,36 @@ func TestCheckErrorPolicy(t *testing.T) { body: []byte(`overloaded`), expected: ErrorPolicyMatched, // custom codes take precedence }, + { + name: "pool_mode_custom_error_codes_hit_returns_matched", + account: &Account{ + ID: 7, + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + 
"pool_mode": true, + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(401), float64(403)}, + }, + }, + statusCode: 401, + body: []byte(`unauthorized`), + expected: ErrorPolicyMatched, + }, + { + name: "pool_mode_without_custom_error_codes_returns_skipped", + account: &Account{ + ID: 8, + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + }, + }, + statusCode: 401, + body: []byte(`unauthorized`), + expected: ErrorPolicySkipped, + }, } for _, tt := range tests { @@ -147,6 +222,48 @@ func TestCheckErrorPolicy(t *testing.T) { } } +func TestHandleUpstreamError_PoolModeCustomErrorCodesOverride(t *testing.T) { + t.Run("pool_mode_without_custom_error_codes_still_skips", func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 30, + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + }, + } + + shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.False(t, shouldDisable) + require.Equal(t, 0, repo.setErrCalls) + require.Equal(t, 0, repo.tempCalls) + }) + + t.Run("pool_mode_with_custom_error_codes_uses_local_error_policy", func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 31, + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(401)}, + }, + } + + shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrCalls) + require.Equal(t, 0, repo.tempCalls) + }) +} + // 
--------------------------------------------------------------------------- // TestApplyErrorPolicy — 4 table-driven cases for the wrapper method // --------------------------------------------------------------------------- diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index f8c0ecda..789cbab8 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd }, } - svc := &GatewayService{ - cfg: &config.Config{ - Gateway: config.GatewayConfig{ - MaxLineSize: defaultMaxLineSize, - }, + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, }, - httpUpstream: upstream, - rateLimitService: &RateLimitService{}, - deferredService: &DeferredService{}, - billingCacheService: nil, + } + svc := &GatewayService{ + cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + deferredService: &DeferredService{}, + billingCacheService: nil, } account := &Account{ @@ -171,8 +173,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd require.NotNil(t, result) require.True(t, result.Stream) - require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体") - require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射") require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) require.Empty(t, upstream.lastReq.Header.Get("authorization")) @@ -190,7 +191,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd require.True(t, ok) bodyBytes, 
ok := rawBody.([]byte) require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝") - require.Equal(t, body, bodyBytes) + require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型") } func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) { @@ -222,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo }, } - svc := &GatewayService{ - cfg: &config.Config{ - Gateway: config.GatewayConfig{ - MaxLineSize: defaultMaxLineSize, - }, + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, }, - httpUpstream: upstream, - rateLimitService: &RateLimitService{}, + } + svc := &GatewayService{ + cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, } account := &Account{ @@ -253,8 +256,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo err := svc.ForwardCountTokens(context.Background(), c, account, parsed) require.NoError(t, err) - require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体") - require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射") require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) require.Empty(t, upstream.lastReq.Header.Get("authorization")) require.Empty(t, upstream.lastReq.Header.Get("cookie")) @@ -263,6 +265,273 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo require.Empty(t, rec.Header().Get("Set-Cookie")) } +// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases 覆盖透传模式下模型映射的各种边界情况 +func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct 
{ + name string + model string + modelMapping map[string]any // nil = 不配置映射 + expectedModel string + endpoint string // "messages" or "count_tokens" + }{ + { + name: "Forward: 无映射配置时不改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: nil, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "messages", + }, + { + name: "Forward: 空映射配置时不改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{}, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "messages", + }, + { + name: "Forward: 模型不在映射表中时不改写", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"}, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "messages", + }, + { + name: "Forward: 精确匹配映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "messages", + }, + { + name: "Forward: 通配符映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "messages", + }, + { + name: "CountTokens: 无映射配置时不改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: nil, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "count_tokens", + }, + { + name: "CountTokens: 模型不在映射表中时不改写", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"}, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "count_tokens", + }, + { + name: "CountTokens: 精确匹配映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "count_tokens", + }, + { + name: "CountTokens: 通配符映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-*": 
"claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "count_tokens", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + body := []byte(`{"model":"` + tt.model + `","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + parsed := &ParsedRequest{ + Body: body, + Model: tt.model, + } + + credentials := map[string]any{ + "api_key": "upstream-key", + "base_url": "https://api.anthropic.com", + } + if tt.modelMapping != nil { + credentials["model_mapping"] = tt.modelMapping + } + + account := &Account{ + ID: 300, + Name: "edge-case-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: credentials, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + if tt.endpoint == "messages" { + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + parsed.Stream = false + + upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":5,"output_tokens":3}}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamJSON)), + }, + } + svc := &GatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + result, err := svc.Forward(context.Background(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(), + "Forward 上游请求体中的模型应为: %s", tt.expectedModel) + } else { + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + upstreamRespBody := `{"input_tokens":42}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + 
Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + svc := &GatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(), + "CountTokens 上游请求体中的模型应为: %s", tt.expectedModel) + } + }) + } +} + +// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields +// 确保模型映射只替换 model 字段,不影响请求体中的其他字段 +func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + // 包含复杂字段的请求体:system、thinking、messages + body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world"}]}],"thinking":{"type":"enabled","budget_tokens":5000},"max_tokens":1024}`) + parsed := &ParsedRequest{ + Body: body, + Model: "claude-sonnet-4-20250514", + } + + upstreamRespBody := `{"input_tokens":42}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 301, + Name: "preserve-fields-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + 
Credentials: map[string]any{ + "api_key": "upstream-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"}, + }, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + + sentBody := upstream.lastBody + require.Equal(t, "claude-sonnet-4-5-20241022", gjson.GetBytes(sentBody, "model").String(), "model 应被映射") + require.Equal(t, "You are a helpful assistant.", gjson.GetBytes(sentBody, "system.0.text").String(), "system 字段不应被修改") + require.Equal(t, "hello world", gjson.GetBytes(sentBody, "messages.0.content.0.text").String(), "messages 字段不应被修改") + require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String(), "thinking 字段不应被修改") + require.Equal(t, int64(5000), gjson.GetBytes(sentBody, "thinking.budget_tokens").Int(), "thinking.budget_tokens 不应被修改") + require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改") +} + +// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping +// 确保空模型名不会触发映射逻辑 +func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`) + parsed := &ParsedRequest{ + Body: body, + Model: "", // 空模型 + } + + upstreamRespBody := `{"input_tokens":10}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: 
defaultMaxLineSize}}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 302, + Name: "empty-model-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"*": "claude-3-opus-20240229"}, + }, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + // 空模型名时,body 应原样透传,不应触发映射 + require.Equal(t, body, upstream.lastBody, "空模型名时请求体不应被修改") +} + func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) { gin.SetMode(gin.TestMode) @@ -462,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf require.Equal(t, 5, result.usage.OutputTokens) } +func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`, + "", + `data: {"type":"message_delta","usage":{"output_tokens":5}}`, + "", + }, "\n"))), + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219") + require.Error(t, err) + 
require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) +} + func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() @@ -809,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi _ = pr.Close() <-done - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "stream usage incomplete after timeout") require.NotNil(t, result) require.True(t, result.clientDisconnect) require.Equal(t, 9, result.usage.InputTokens) @@ -838,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t } result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219") - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "stream usage incomplete") require.NotNil(t, result) require.True(t, result.clientDisconnect) } @@ -868,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft } result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219") - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "stream usage incomplete after disconnect") require.NotNil(t, result) require.True(t, result.clientDisconnect) require.Equal(t, 8, result.usage.InputTokens) diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go index 21a1faa4..ecaffe21 100644 --- a/backend/internal/service/gateway_beta_test.go +++ b/backend/internal/service/gateway_beta_test.go @@ -86,10 +86,10 @@ func TestStripBetaTokens(t *testing.T) { want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", }, { - name: "DroppedBetas removes both context-1m and fast-mode", + name: 
"DroppedBetas is empty (filtering moved to configurable beta policy)", header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", tokens: claude.DroppedBetas, - want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + want: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", }, } @@ -114,25 +114,23 @@ func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) { func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) { required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"} incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20" + // DroppedBetas is now empty — filtering moved to configurable beta policy. + // Without a policy filter set, nothing gets dropped from the static set. drop := droppedBetaSet() got := mergeAnthropicBetaDropping(required, incoming, drop) - require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got) - require.NotContains(t, got, "context-1m-2025-08-07") - require.NotContains(t, got, "fast-mode-2026-02-01") + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta", got) + require.Contains(t, got, "context-1m-2025-08-07") + require.Contains(t, got, "fast-mode-2026-02-01") } func TestDroppedBetaSet(t *testing.T) { - // Base set contains DroppedBetas + // Base set contains DroppedBetas (now empty — filtering moved to configurable beta policy) base := droppedBetaSet() - require.Contains(t, base, claude.BetaContext1M) - require.Contains(t, base, claude.BetaFastMode) require.Len(t, base, len(claude.DroppedBetas)) // With extra tokens extended := droppedBetaSet(claude.BetaClaudeCode) - require.Contains(t, extended, claude.BetaContext1M) - require.Contains(t, extended, claude.BetaFastMode) require.Contains(t, extended, claude.BetaClaudeCode) require.Len(t, extended, len(claude.DroppedBetas)+1) } @@ -148,6 +146,32 @@ func 
TestBuildBetaTokenSet(t *testing.T) { require.Empty(t, empty) } +func TestContainsBetaToken(t *testing.T) { + tests := []struct { + name string + header string + token string + want bool + }{ + {"present in middle", "oauth-2025-04-20,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true}, + {"present at start", "fast-mode-2026-02-01,oauth-2025-04-20", "fast-mode-2026-02-01", true}, + {"present at end", "oauth-2025-04-20,fast-mode-2026-02-01", "fast-mode-2026-02-01", true}, + {"only token", "fast-mode-2026-02-01", "fast-mode-2026-02-01", true}, + {"not present", "oauth-2025-04-20,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", false}, + {"with spaces", "oauth-2025-04-20, fast-mode-2026-02-01 , interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true}, + {"empty header", "", "fast-mode-2026-02-01", false}, + {"empty token", "fast-mode-2026-02-01", "", false}, + {"partial match", "fast-mode-2026-02-01-extra", "fast-mode-2026-02-01", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := containsBetaToken(tt.header, tt.token) + require.Equal(t, tt.want, got) + }) + } +} + func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) { header := "oauth-2025-04-20,interleaved-thinking-2025-05-14" got := stripBetaTokensWithSet(header, map[string]struct{}{}) diff --git a/backend/internal/service/gateway_group_isolation_test.go b/backend/internal/service/gateway_group_isolation_test.go new file mode 100644 index 00000000..00508f0e --- /dev/null +++ b/backend/internal/service/gateway_group_isolation_test.go @@ -0,0 +1,363 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// ============================================================================ +// Part 1: isAccountInGroup 单元测试 +// ============================================================================ + +func 
TestIsAccountInGroup(t *testing.T) { + svc := &GatewayService{} + groupID100 := int64(100) + groupID200 := int64(200) + + tests := []struct { + name string + account *Account + groupID *int64 + expected bool + }{ + // groupID == nil(无分组 API Key) + { + "nil_groupID_ungrouped_account_nil_groups", + &Account{ID: 1, AccountGroups: nil}, + nil, true, + }, + { + "nil_groupID_ungrouped_account_empty_slice", + &Account{ID: 2, AccountGroups: []AccountGroup{}}, + nil, true, + }, + { + "nil_groupID_grouped_account_single", + &Account{ID: 3, AccountGroups: []AccountGroup{{GroupID: 100}}}, + nil, false, + }, + { + "nil_groupID_grouped_account_multiple", + &Account{ID: 4, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}}, + nil, false, + }, + // groupID != nil(有分组 API Key) + { + "with_groupID_account_in_group", + &Account{ID: 5, AccountGroups: []AccountGroup{{GroupID: 100}}}, + &groupID100, true, + }, + { + "with_groupID_account_not_in_group", + &Account{ID: 6, AccountGroups: []AccountGroup{{GroupID: 200}}}, + &groupID100, false, + }, + { + "with_groupID_ungrouped_account", + &Account{ID: 7, AccountGroups: nil}, + &groupID100, false, + }, + { + "with_groupID_multi_group_account_match_one", + &Account{ID: 8, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}}, + &groupID200, true, + }, + { + "with_groupID_multi_group_account_no_match", + &Account{ID: 9, AccountGroups: []AccountGroup{{GroupID: 300}, {GroupID: 400}}}, + &groupID100, false, + }, + // 防御性边界 + { + "nil_account_nil_groupID", + nil, + nil, false, + }, + { + "nil_account_with_groupID", + nil, + &groupID100, false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.isAccountInGroup(tt.account, tt.groupID) + require.Equal(t, tt.expected, got, "isAccountInGroup 结果不符预期") + }) + } +} + +// ============================================================================ +// Part 2: 分组隔离端到端调度测试 +// 
============================================================================ + +// groupAwareMockAccountRepo 嵌入 mockAccountRepoForPlatform,覆写分组隔离相关方法。 +// allAccounts 存储所有账号,分组查询方法按 AccountGroups 字段进行真实过滤。 +type groupAwareMockAccountRepo struct { + *mockAccountRepoForPlatform + allAccounts []Account +} + +// ListSchedulableUngroupedByPlatform 仅返回未分组账号(AccountGroups 为空) +func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + var result []Account + for _, acc := range m.allAccounts { + if acc.Platform == platform && acc.IsSchedulable() && len(acc.AccountGroups) == 0 { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableUngroupedByPlatforms 仅返回未分组账号(多平台版本) +func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + platformSet := make(map[string]bool, len(platforms)) + for _, p := range platforms { + platformSet[p] = true + } + var result []Account + for _, acc := range m.allAccounts { + if platformSet[acc.Platform] && acc.IsSchedulable() && len(acc.AccountGroups) == 0 { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableByGroupIDAndPlatform 返回属于指定分组的账号 +func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + var result []Account + for _, acc := range m.allAccounts { + if acc.Platform == platform && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableByGroupIDAndPlatforms 返回属于指定分组的账号(多平台版本) +func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { + platformSet := make(map[string]bool, len(platforms)) + for _, p := range platforms { + platformSet[p] = true + } + var result 
[]Account + for _, acc := range m.allAccounts { + if platformSet[acc.Platform] && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) { + result = append(result, acc) + } + } + return result, nil +} + +// accountBelongsToGroup 检查账号是否属于指定分组 +func accountBelongsToGroup(acc Account, groupID int64) bool { + for _, ag := range acc.AccountGroups { + if ag.GroupID == groupID { + return true + } + } + return false +} + +// Verify interface implementation +var _ AccountRepository = (*groupAwareMockAccountRepo)(nil) + +// newGroupAwareMockRepo 创建分组感知的 mock repo +func newGroupAwareMockRepo(accounts []Account) *groupAwareMockAccountRepo { + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + return &groupAwareMockAccountRepo{ + mockAccountRepoForPlatform: &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + }, + allAccounts: accounts, + } +} + +func TestGroupIsolation_UngroupedKey_ShouldNotScheduleGroupedAccounts(t *testing.T) { + // 场景:无分组 API Key(groupID=nil),池中只有已分组账号 → 应返回错误 + ctx := context.Background() + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.Error(t, err, "无分组 Key 不应调度到已分组账号") + require.Nil(t, acc) +} + +func TestGroupIsolation_GroupedKey_ShouldNotScheduleUngroupedAccounts(t *testing.T) { + // 场景:有分组 API Key(groupID=100),池中只有未分组账号 → 应返回错误 + ctx := context.Background() + groupID := int64(100) + + accounts := []Account{ + {ID: 1, Platform: 
PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{}}, + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI) + require.Error(t, err, "有分组 Key 不应调度到未分组账号") + require.Nil(t, acc) +} + +func TestGroupIsolation_UngroupedKey_ShouldOnlyScheduleUngroupedAccounts(t *testing.T) { + // 场景:无分组 API Key(groupID=nil),池中有未分组和已分组账号 → 应只选中未分组的 + ctx := context.Background() + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组,不应被选中 + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组,应被选中 + {ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, // 已分组,不应被选中 + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "应成功调度未分组账号") + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "应选中未分组的账号 ID=2") +} + +func TestGroupIsolation_GroupedKey_ShouldOnlyScheduleMatchingGroupAccounts(t *testing.T) { + // 场景:有分组 API Key(groupID=100),池中有未分组和多个分组账号 → 应只选中分组 100 内的 + ctx := context.Background() + groupID := int64(100) + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组,不应被选中 + {ID: 2, Platform: PlatformOpenAI, 
Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, // 属于分组 200,不应被选中 + {ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 属于分组 100,应被选中 + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "应成功调度分组内账号") + require.NotNil(t, acc) + require.Equal(t, int64(3), acc.ID, "应选中分组 100 内的账号 ID=3") +} + +// ============================================================================ +// Part 3: SimpleMode 旁路测试 +// ============================================================================ + +func TestGroupIsolation_SimpleMode_SkipsGroupIsolation(t *testing.T) { + // SimpleMode 应跳过分组隔离,使用 ListSchedulableByPlatform 返回所有账号。 + // 测试非 useMixed 路径(platform=openai,不会触发 mixed 调度逻辑)。 + ctx := context.Background() + + // 混合未分组和已分组账号,SimpleMode 下应全部可调度 + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组 + {ID: 2, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组 + } + + // 使用基础 mock(ListSchedulableByPlatform 返回所有匹配平台的账号,不做分组过滤) + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + repo := &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + } + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: &config.Config{RunMode: config.RunModeSimple}, + } + + // groupID=nil 时,SimpleMode 应使用 ListSchedulableByPlatform(不过滤分组) + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", 
"", nil, PlatformOpenAI) + require.NoError(t, err, "SimpleMode 应跳过分组隔离直接返回账号") + require.NotNil(t, acc) + // 应选择优先级最高的账号(Priority=1, ID=2),即使它未分组 + require.Equal(t, int64(2), acc.ID, "SimpleMode 应按优先级选择,不考虑分组") +} + +func TestGroupIsolation_SimpleMode_GroupedAccountAlsoSchedulable(t *testing.T) { + // SimpleMode + groupID=nil 时,已分组账号也应该可被调度 + ctx := context.Background() + + // 只有已分组账号,在 standard 模式下 groupID=nil 会报错,但 simple 模式应正常 + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, + } + + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + repo := &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + } + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: &config.Config{RunMode: config.RunModeSimple}, + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "SimpleMode 下已分组账号也应可调度") + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID, "SimpleMode 应能调度已分组账号") +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 067a0e08..ea8fa784 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -147,6 +147,12 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Cont func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { return m.ListSchedulableByPlatforms(ctx, platforms) } +func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return m.ListSchedulableByPlatform(ctx, platform) +} +func (m 
*mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + return m.ListSchedulableByPlatforms(ctx, platforms) +} func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } @@ -181,6 +187,14 @@ func (m *mockAccountRepoForPlatform) BulkUpdate(ctx context.Context, ids []int64 return 0, nil } +func (m *mockAccountRepoForPlatform) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func (m *mockAccountRepoForPlatform) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + // Verify interface implementation var _ AccountRepository = (*mockAccountRepoForPlatform)(nil) @@ -426,7 +440,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) require.Error(t, err) require.Nil(t, acc) - require.Contains(t, err.Error(), "no available accounts") + require.ErrorIs(t, err, ErrNoAvailableAccounts) } // TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded 测试所有账户被排除 @@ -1059,7 +1073,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAccounts(t *testing. 
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformAnthropic) require.Error(t, err) require.Nil(t, acc) - require.Contains(t, err.Error(), "no available accounts") + require.ErrorIs(t, err, ErrNoAvailableAccounts) } func TestGatewayService_isModelSupportedByAccount(t *testing.T) { @@ -1720,7 +1734,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) require.Error(t, err) require.Nil(t, acc) - require.Contains(t, err.Error(), "no available accounts") + require.ErrorIs(t, err, ErrNoAvailableAccounts) }) t.Run("混合调度-不支持模型返回错误", func(t *testing.T) { @@ -1972,6 +1986,10 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a return nil } +func (m *mockConcurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + return nil +} + func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { result := make(map[int64]*UserLoadInfo, len(users)) for _, user := range users { @@ -2272,7 +2290,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") require.Error(t, err) require.Nil(t, result) - require.Contains(t, err.Error(), "no available accounts") + require.ErrorIs(t, err, ErrNoAvailableAccounts) }) t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) { diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go new file mode 100644 index 00000000..4c1f0317 --- /dev/null +++ b/backend/internal/service/gateway_record_usage_test.go @@ -0,0 +1,422 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + 
"github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 1.1 + return NewGatewayService( + nil, + nil, + usageRepo, + nil, + userRepo, + subRepo, + nil, + nil, + cfg, + nil, + nil, + NewBillingService(cfg, nil), + nil, + &BillingCacheService{}, + nil, + nil, + &DeferredService{}, + nil, + nil, + nil, + nil, + nil, + ) +} + +func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService { + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + svc.usageBillingRepo = billingRepo + return svc +} + +type openAIRecordUsageBestEffortLogRepoStub struct { + UsageLogRepository + + bestEffortErr error + createErr error + bestEffortCalls int + createCalls int + lastLog *UsageLog + lastCtxErr error +} + +func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error { + s.bestEffortCalls++ + s.lastLog = log + s.lastCtxErr = ctx.Err() + return s.bestEffortErr +} + +func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { + s.createCalls++ + s.lastLog = log + s.lastCtxErr = ctx.Err() + return false, s.createErr +} + +func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + reqCtx, cancel := 
context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_detached_ctx", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 501, + Quota: 100, + }, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`)) + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_payload_hash", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + RequestPayloadHash: payloadHash, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash) +} + +func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := 
newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_payload_fallback", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash) +} + +func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_not_persisted", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 503, + Quota: 100, + }, + User: &User{ID: 603}, + Account: &Account{ID: 703}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 1, quotaSvc.quotaCalls) +} + +func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := 
&openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{ + Result: &ForwardResult{ + RequestID: "gateway_long_context_detached_ctx", + Usage: ClaudeUsage{ + InputTokens: 12, + OutputTokens: 8, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 502, + Quota: 100, + }, + User: &User{ID: 602}, + Account: &Account{ID: 702}, + LongContextThreshold: 200000, + LongContextMultiplier: 2, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 504}, + User: &User{ID: 604}, + Account: &Account{ID: 704}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + 
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123") + ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "upstream-volatile-456", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 506}, + User: &User{ID: 606}, + Account: &Account{ID: 706}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 507}, + User: &User{ID: 607}, + Account: &Account{ID: 707}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:")) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) { + usageRepo := 
&openAIRecordUsageBestEffortLogRepoStub{ + bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")), + } + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_drop_usage_log", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 508}, + User: &User{ID: 608}, + Account: &Account{ID: 708}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.bestEffortCalls) + require.Equal(t, 0, usageRepo.createCalls) +} + +func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_billing_fail", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 505}, + User: &User{ID: 605}, + Account: &Account{ID: 705}, + }) + + require.Error(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 0, usageRepo.calls) +} + +func TestGatewayServiceRecordUsage_ReasoningEffortPersisted(t *testing.T) { + usageRepo := &openAIRecordUsageBestEffortLogRepoStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + effort := "max" + err 
:= svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "effort_test", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + Model: "claude-opus-4-6", + Duration: time.Second, + ReasoningEffort: &effort, + }, + APIKey: &APIKey{ID: 1}, + User: &User{ID: 1}, + Account: &Account{ID: 1}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.NotNil(t, usageRepo.lastLog.ReasoningEffort) + require.Equal(t, "max", *usageRepo.lastLog.ReasoningEffort) +} + +func TestGatewayServiceRecordUsage_ReasoningEffortNil(t *testing.T) { + usageRepo := &openAIRecordUsageBestEffortLogRepoStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "no_effort_test", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1}, + User: &User{ID: 1}, + Account: &Account{ID: 1}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Nil(t, usageRepo.lastLog.ReasoningEffort) +} diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index f8096a0e..3816aea9 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "math" + "strings" "unsafe" "github.com/Wei-Shaw/sub2api/internal/domain" @@ -59,8 +60,13 @@ type ParsedRequest struct { Messages []any // messages 数组 HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) + OutputEffort string // output_config.effort(Claude API 的推理强度控制) MaxTokens int // max_tokens 值(用于探测请求拦截) SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变) + + // OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁) + // 流式请求在收到 
2xx 响应头后调用,避免持锁等流完成 + OnUpstreamAccepted func() } // ParseGatewayRequest 解析网关请求体并返回结构化结果。 @@ -111,6 +117,9 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { parsed.ThinkingEnabled = true } + // output_config.effort: Claude API 的推理强度控制参数 + parsed.OutputEffort = strings.TrimSpace(gjson.Get(jsonStr, "output_config.effort").String()) + // max_tokens: 仅接受整数值 maxTokensResult := gjson.Get(jsonStr, "max_tokens") if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number { @@ -254,6 +263,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { if !hasEmptyContent && !containsThinkingBlocks { if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() { if out, err := sjson.DeleteBytes(body, "thinking"); err == nil { + out = removeThinkingDependentContextStrategies(out) return out } return body @@ -391,6 +401,10 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { } else { return body } + // Removing "thinking" makes any context_management strategy that requires it invalid + // (e.g. clear_thinking_20251015). Strip those entries so the retry request does not + // receive a 400 "strategy requires thinking to be enabled or adaptive". 
+ out = removeThinkingDependentContextStrategies(out) } if modified { msgsBytes, err := json.Marshal(messages) @@ -405,6 +419,49 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { return out } +// removeThinkingDependentContextStrategies 从 context_management.edits 中移除 +// 需要 thinking 启用的策略(如 clear_thinking_20251015)。 +// 当顶层 "thinking" 字段被禁用时必须调用,否则上游会返回 +// "strategy requires thinking to be enabled or adaptive"。 +func removeThinkingDependentContextStrategies(body []byte) []byte { + jsonStr := *(*string)(unsafe.Pointer(&body)) + editsRes := gjson.Get(jsonStr, "context_management.edits") + if !editsRes.Exists() || !editsRes.IsArray() { + return body + } + + var filtered []json.RawMessage + hasRemoved := false + editsRes.ForEach(func(_, v gjson.Result) bool { + if v.Get("type").String() == "clear_thinking_20251015" { + hasRemoved = true + return true + } + filtered = append(filtered, json.RawMessage(v.Raw)) + return true + }) + + if !hasRemoved { + return body + } + + if len(filtered) == 0 { + if b, err := sjson.DeleteBytes(body, "context_management.edits"); err == nil { + return b + } + return body + } + + filteredBytes, err := json.Marshal(filtered) + if err != nil { + return body + } + if b, err := sjson.SetRawBytes(body, "context_management.edits", filteredBytes); err == nil { + return b + } + return body +} + // FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate // signature/thought_signature validation issues involving tool blocks. // @@ -440,6 +497,28 @@ func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte { if _, exists := req["thinking"]; exists { delete(req, "thinking") modified = true + // Remove context_management strategies that require thinking to be enabled + // (e.g. clear_thinking_20251015), otherwise upstream returns 400. 
+ if cm, ok := req["context_management"].(map[string]any); ok { + if edits, ok := cm["edits"].([]any); ok { + filtered := make([]any, 0, len(edits)) + for _, edit := range edits { + if editMap, ok := edit.(map[string]any); ok { + if editMap["type"] == "clear_thinking_20251015" { + continue + } + } + filtered = append(filtered, edit) + } + if len(filtered) != len(edits) { + if len(filtered) == 0 { + delete(cm, "edits") + } else { + cm["edits"] = filtered + } + } + } + } } messages, ok := req["messages"].([]any) @@ -671,3 +750,105 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte { } return newBody } + +// NormalizeClaudeOutputEffort normalizes Claude's output_config.effort value. +// Returns nil for empty or unrecognized values. +func NormalizeClaudeOutputEffort(raw string) *string { + value := strings.ToLower(strings.TrimSpace(raw)) + if value == "" { + return nil + } + switch value { + case "low", "medium", "high", "max": + return &value + default: + return nil + } +} + +// ========================= +// Thinking Budget Rectifier +// ========================= + +const ( + // BudgetRectifyBudgetTokens is the budget_tokens value to set when rectifying. + BudgetRectifyBudgetTokens = 32000 + // BudgetRectifyMaxTokens is the max_tokens value to set when rectifying. + BudgetRectifyMaxTokens = 64000 + // BudgetRectifyMinMaxTokens is the minimum max_tokens that must exceed budget_tokens. + BudgetRectifyMinMaxTokens = 32001 +) + +// isThinkingBudgetConstraintError detects whether an upstream error message indicates +// a budget_tokens constraint violation (e.g. "budget_tokens >= 1024"). +// Matches three conditions (all must be true): +// 1. Contains "budget_tokens" or "budget tokens" +// 2. Contains "thinking" +// 3. 
Contains ">= 1024" or "greater than or equal to 1024" or ("1024" + "input should be") +func isThinkingBudgetConstraintError(errMsg string) bool { + m := strings.ToLower(errMsg) + + // Condition 1: budget_tokens or budget tokens + hasBudget := strings.Contains(m, "budget_tokens") || strings.Contains(m, "budget tokens") + if !hasBudget { + return false + } + + // Condition 2: thinking + if !strings.Contains(m, "thinking") { + return false + } + + // Condition 3: constraint indicator + if strings.Contains(m, ">= 1024") || strings.Contains(m, "greater than or equal to 1024") { + return true + } + if strings.Contains(m, "1024") && strings.Contains(m, "input should be") { + return true + } + + return false +} + +// RectifyThinkingBudget modifies the request body to fix budget_tokens constraint errors. +// It sets thinking.budget_tokens = 32000, thinking.type = "enabled" (unless adaptive), +// and ensures max_tokens >= 32001. +// Returns (modified body, true) if changes were applied, or (original body, false) if not. 
+func RectifyThinkingBudget(body []byte) ([]byte, bool) { + // If thinking type is "adaptive", skip rectification entirely + thinkingType := gjson.GetBytes(body, "thinking.type").String() + if thinkingType == "adaptive" { + return body, false + } + + modified := body + changed := false + + // Set thinking.type = "enabled" + if thinkingType != "enabled" { + if result, err := sjson.SetBytes(modified, "thinking.type", "enabled"); err == nil { + modified = result + changed = true + } + } + + // Set thinking.budget_tokens = 32000 + currentBudget := gjson.GetBytes(modified, "thinking.budget_tokens").Int() + if currentBudget != BudgetRectifyBudgetTokens { + if result, err := sjson.SetBytes(modified, "thinking.budget_tokens", BudgetRectifyBudgetTokens); err == nil { + modified = result + changed = true + } + } + + // Ensure max_tokens >= BudgetRectifyMinMaxTokens + maxTokens := gjson.GetBytes(modified, "max_tokens").Int() + if maxTokens < int64(BudgetRectifyMinMaxTokens) { + if result, err := sjson.SetBytes(modified, "max_tokens", BudgetRectifyMaxTokens); err == nil { + modified = result + changed = true + } + } + + return modified, changed +} diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index 2a9b4017..f60ed9fb 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -439,6 +439,210 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) { require.Contains(t, content1["text"], "tool_result") } +// ============ Group 6b: context_management.edits 清理测试 ============ + +// removeThinkingDependentContextStrategies — 边界用例 + +func TestRemoveThinkingDependentContextStrategies_NoContextManagement(t *testing.T) { + input := []byte(`{"thinking":{"type":"enabled"},"messages":[]}`) + out := removeThinkingDependentContextStrategies(input) + require.Equal(t, input, out, "无 context_management 字段时应原样返回") +} + +func 
TestRemoveThinkingDependentContextStrategies_EmptyEdits(t *testing.T) { + input := []byte(`{"context_management":{"edits":[]},"messages":[]}`) + out := removeThinkingDependentContextStrategies(input) + require.Equal(t, input, out, "edits 为空数组时应原样返回") +} + +func TestRemoveThinkingDependentContextStrategies_NoClearThinkingEntry(t *testing.T) { + input := []byte(`{"context_management":{"edits":[{"type":"other_strategy"}]},"messages":[]}`) + out := removeThinkingDependentContextStrategies(input) + require.Equal(t, input, out, "edits 中无 clear_thinking_20251015 时应原样返回") +} + +func TestRemoveThinkingDependentContextStrategies_RemovesSingleEntry(t *testing.T) { + input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`) + out := removeThinkingDependentContextStrategies(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + _, hasEdits := cm["edits"] + require.False(t, hasEdits, "所有 edits 均为 clear_thinking_20251015 时应删除 edits 键") +} + +func TestRemoveThinkingDependentContextStrategies_MixedEntries(t *testing.T) { + input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_strategy","param":1}]},"messages":[]}`) + out := removeThinkingDependentContextStrategies(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + edits, ok := cm["edits"].([]any) + require.True(t, ok) + require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留其他条目") + edit0, ok := edits[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "other_strategy", edit0["type"]) +} + +// FilterThinkingBlocksForRetry — 包含 context_management 的场景 + +func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_FastPath(t *testing.T) { + // 快速路径:messages 中无 thinking 块,仅有顶层 thinking 字段 + // 这条路径曾因提前 return 跳过 
removeThinkingDependentContextStrategies 而存在 bug + input := []byte(`{ + "thinking":{"type":"enabled","budget_tokens":1024}, + "context_management":{"edits":[{"type":"clear_thinking_20251015"}]}, + "messages":[ + {"role":"user","content":[{"type":"text","text":"Hello"}]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + _, hasThinking := req["thinking"] + require.False(t, hasThinking, "顶层 thinking 应被移除") + + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + _, hasEdits := cm["edits"] + require.False(t, hasEdits, "fast path 下 clear_thinking_20251015 应被移除,edits 键应被删除") +} + +func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_WithThinkingBlocks(t *testing.T) { + // 完整路径:messages 中有 thinking 块(非 fast path) + input := []byte(`{ + "thinking":{"type":"enabled","budget_tokens":1024}, + "context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"keep_this"}]}, + "messages":[ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"some thought","signature":"sig"}, + {"type":"text","text":"Answer"} + ]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + _, hasThinking := req["thinking"] + require.False(t, hasThinking, "顶层 thinking 应被移除") + + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + edits, ok := cm["edits"].([]any) + require.True(t, ok) + require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留 keep_this") + edit0, ok := edits[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "keep_this", edit0["type"]) +} + +func TestFilterThinkingBlocksForRetry_NoContextManagement_Unaffected(t *testing.T) { + // 无 context_management 时不应报错,且 thinking 正常被移除 + input := []byte(`{ + "thinking":{"type":"enabled"}, + "messages":[{"role":"user","content":[{"type":"text","text":"Hi"}]}] + }`) + + out := 
FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + _, hasThinking := req["thinking"] + require.False(t, hasThinking) + _, hasCM := req["context_management"] + require.False(t, hasCM) +} + +// FilterSignatureSensitiveBlocksForRetry — 包含 context_management 的场景 + +func TestFilterSignatureSensitiveBlocksForRetry_RemovesClearThinkingStrategy(t *testing.T) { + input := []byte(`{ + "thinking":{"type":"enabled","budget_tokens":1024}, + "context_management":{"edits":[{"type":"clear_thinking_20251015"}]}, + "messages":[ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"thought","signature":"sig"} + ]} + ] + }`) + + out := FilterSignatureSensitiveBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + _, hasThinking := req["thinking"] + require.False(t, hasThinking, "顶层 thinking 应被移除") + + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + if rawEdits, hasEdits := cm["edits"]; hasEdits { + edits, ok := rawEdits.([]any) + require.True(t, ok) + for _, e := range edits { + em, ok := e.(map[string]any) + require.True(t, ok) + require.NotEqual(t, "clear_thinking_20251015", em["type"], "clear_thinking_20251015 应被移除") + } + } +} + +func TestFilterSignatureSensitiveBlocksForRetry_PreservesNonThinkingStrategies(t *testing.T) { + input := []byte(`{ + "thinking":{"type":"enabled"}, + "context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_edit"}]}, + "messages":[ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"t","signature":"s"} + ]} + ] + }`) + + out := FilterSignatureSensitiveBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + edits, ok := cm["edits"].([]any) + require.True(t, ok) + require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留 other_edit") + edit0, ok := 
edits[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "other_edit", edit0["type"]) +} + +func TestFilterSignatureSensitiveBlocksForRetry_NoThinkingField_ContextManagementUntouched(t *testing.T) { + // 没有顶层 thinking 字段时,context_management 不应被修改 + input := []byte(`{ + "context_management":{"edits":[{"type":"clear_thinking_20251015"}]}, + "messages":[ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"t","signature":"s"} + ]} + ] + }`) + + out := FilterSignatureSensitiveBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + edits, ok := cm["edits"].([]any) + require.True(t, ok) + require.Len(t, edits, 1, "无顶层 thinking 时 context_management 不应被修改") +} + // ============ Group 7: ParseGatewayRequest 补充单元测试 ============ // Task 7.1 — 类型校验边界测试 @@ -768,6 +972,76 @@ func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) { } } +func TestParseGatewayRequest_OutputEffort(t *testing.T) { + tests := []struct { + name string + body string + wantEffort string + }{ + { + name: "output_config.effort present", + body: `{"model":"claude-opus-4-6","output_config":{"effort":"medium"},"messages":[]}`, + wantEffort: "medium", + }, + { + name: "output_config.effort max", + body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`, + wantEffort: "max", + }, + { + name: "output_config without effort", + body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`, + wantEffort: "", + }, + { + name: "no output_config", + body: `{"model":"claude-opus-4-6","messages":[]}`, + wantEffort: "", + }, + { + name: "effort with whitespace trimmed", + body: `{"model":"claude-opus-4-6","output_config":{"effort":" high "},"messages":[]}`, + wantEffort: "high", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parsed, err := ParseGatewayRequest([]byte(tt.body), "") + require.NoError(t, err) + 
require.Equal(t, tt.wantEffort, parsed.OutputEffort) + }) + } +} + +func TestNormalizeClaudeOutputEffort(t *testing.T) { + tests := []struct { + input string + want *string + }{ + {"low", strPtr("low")}, + {"medium", strPtr("medium")}, + {"high", strPtr("high")}, + {"max", strPtr("max")}, + {"LOW", strPtr("low")}, + {"Max", strPtr("max")}, + {" medium ", strPtr("medium")}, + {"", nil}, + {"unknown", nil}, + {"xhigh", nil}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := NormalizeClaudeOutputEffort(tt.input) + if tt.want == nil { + require.Nil(t, got) + } else { + require.NotNil(t, got) + require.Equal(t, *tt.want, *got) + } + }) + } +} + func BenchmarkParseGatewayRequest_New_Large(b *testing.B) { data := buildLargeJSON() b.SetBytes(int64(len(data))) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 3323f868..0b50162a 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -41,7 +41,7 @@ const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" stickySessionTTL = time.Hour // 粘性会话TTL - defaultMaxLineSize = 40 * 1024 * 1024 + defaultMaxLineSize = 500 * 1024 * 1024 // Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines) // to match real Claude CLI traffic as closely as possible. When we need a visual // separator between system blocks, we add "\n\n" at concatenation time. 
@@ -50,6 +50,7 @@ const ( defaultUserGroupRateCacheTTL = 30 * time.Second defaultModelsListCacheTTL = 15 * time.Second + postUsageBillingTimeout = 15 * time.Second ) const ( @@ -106,6 +107,36 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() } +func openAIStreamEventIsTerminal(data string) bool { + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if trimmed == "[DONE]" { + return true + } + switch gjson.Get(trimmed, "type").String() { + case "response.completed", "response.done", "response.failed": + return true + default: + return false + } +} + +func anthropicStreamEventIsTerminal(eventName, data string) bool { + if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") { + return true + } + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if trimmed == "[DONE]" { + return true + } + return gjson.Get(trimmed, "type").String() == "message_stop" +} + func cloneStringSlice(src []string) []string { if len(src) == 0 { return nil @@ -315,6 +346,9 @@ var systemBlockFilterPrefixes = []string{ "x-anthropic-billing-header", } +// ErrNoAvailableAccounts 表示没有可用的账号 +var ErrNoAvailableAccounts = errors.New("no available accounts") + // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") @@ -461,6 +495,7 @@ type ForwardResult struct { Duration time.Duration FirstTokenMs *int // 首字时间(流式请求) ClientDisconnect bool // 客户端是否在流式传输过程中断开 + ReasoningEffort *string // 图片生成计费字段(图片生成模型使用) ImageCount int // 生成的图片数量 @@ -501,33 +536,36 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou // GatewayService handles API gateway operations type GatewayService struct { - accountRepo AccountRepository - groupRepo GroupRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo 
UserSubscriptionRepository - userGroupRateRepo UserGroupRateRepository - cache GatewayCache - digestStore *DigestSessionStore - cfg *config.Config - schedulerSnapshot *SchedulerSnapshotService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - identityService *IdentityService - httpUpstream HTTPUpstream - deferredService *DeferredService - concurrencyService *ConcurrencyService - claudeTokenProvider *ClaudeTokenProvider - sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) - rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) - userGroupRateCache *gocache.Cache - userGroupRateSF singleflight.Group - modelsListCache *gocache.Cache - modelsListCacheTTL time.Duration - responseHeaderFilter *responseheaders.CompiledHeaderFilter - debugModelRouting atomic.Bool - debugClaudeMimic atomic.Bool + accountRepo AccountRepository + groupRepo GroupRepository + usageLogRepo UsageLogRepository + usageBillingRepo UsageBillingRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache GatewayCache + digestStore *DigestSessionStore + cfg *config.Config + schedulerSnapshot *SchedulerSnapshotService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + identityService *IdentityService + httpUpstream HTTPUpstream + deferredService *DeferredService + concurrencyService *ConcurrencyService + claudeTokenProvider *ClaudeTokenProvider + sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) + userGroupRateResolver *userGroupRateResolver + userGroupRateCache *gocache.Cache + userGroupRateSF singleflight.Group + modelsListCache *gocache.Cache + modelsListCacheTTL time.Duration + settingService *SettingService + responseHeaderFilter *responseheaders.CompiledHeaderFilter + 
debugModelRouting atomic.Bool + debugClaudeMimic atomic.Bool } // NewGatewayService creates a new GatewayService @@ -535,6 +573,7 @@ func NewGatewayService( accountRepo AccountRepository, groupRepo GroupRepository, usageLogRepo UsageLogRepository, + usageBillingRepo UsageBillingRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, userGroupRateRepo UserGroupRateRepository, @@ -552,6 +591,7 @@ func NewGatewayService( sessionLimitCache SessionLimitCache, rpmCache RPMCache, digestStore *DigestSessionStore, + settingService *SettingService, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) @@ -560,6 +600,7 @@ func NewGatewayService( accountRepo: accountRepo, groupRepo: groupRepo, usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, userRepo: userRepo, userSubRepo: userSubRepo, userGroupRateRepo: userGroupRateRepo, @@ -578,10 +619,18 @@ func NewGatewayService( sessionLimitCache: sessionLimitCache, rpmCache: rpmCache, userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), + settingService: settingService, modelsListCache: gocache.New(modelsListTTL, time.Minute), modelsListCacheTTL: modelsListTTL, responseHeaderFilter: compileResponseHeaderFilter(cfg), } + svc.userGroupRateResolver = newUserGroupRateResolver( + userGroupRateRepo, + svc.userGroupRateCache, + userGroupRateTTL, + &svc.userGroupRateSF, + "service.gateway", + ) svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) return svc @@ -986,6 +1035,11 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID) } +// GenerateSessionUUID creates a deterministic UUID4 from a seed string. 
+func GenerateSessionUUID(seed string) string { + return generateSessionUUID(seed) +} + func generateSessionUUID(seed string) string { if seed == "" { return uuid.NewString() @@ -1154,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return nil, err } if len(accounts) == 0 { - return nil, errors.New("no available accounts") + return nil, ErrNoAvailableAccounts } ctx = s.withWindowCostPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts) @@ -1228,6 +1282,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID) continue } + // 配额检查 + if !s.isAccountSchedulableForQuota(account) { + continue + } // 窗口费用检查(非粘性会话路径) if !s.isAccountSchedulableForWindowCost(ctx, account, false) { filteredWindowCost++ @@ -1260,6 +1318,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && + s.isAccountSchedulableForQuota(stickyAccount) && s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 @@ -1311,7 +1370,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro for _, acc := range routingCandidates { routingLoads = append(routingLoads, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: acc.Concurrency, + MaxConcurrency: acc.EffectiveLoadFactor(), }) } routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads) @@ -1416,6 +1475,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro s.isAccountAllowedForPlatform(account, platform, useMixed) && (requestedModel == "" || 
s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && + s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查 @@ -1480,6 +1540,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + // 配额检查 + if !s.isAccountSchedulableForQuota(acc) { + continue + } // 窗口费用检查(非粘性会话路径) if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue @@ -1492,14 +1556,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } if len(candidates) == 0 { - return nil, errors.New("no available accounts") + return nil, ErrNoAvailableAccounts } accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) for _, acc := range candidates { accountLoads = append(accountLoads, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: acc.Concurrency, + MaxConcurrency: acc.EffectiveLoadFactor(), }) } @@ -1581,7 +1645,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro }, }, nil } - return nil, errors.New("no available accounts") + return nil, ErrNoAvailableAccounts } func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { @@ -1782,8 +1846,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i var err error if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) - } else { + } else if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, 
platforms) } if err != nil { slog.Debug("account_scheduling_list_failed", @@ -1824,7 +1890,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) // 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询 } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, platform) } if err != nil { slog.Debug("account_scheduling_list_failed", @@ -1964,14 +2030,15 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte } // isAccountInGroup checks if the account belongs to the specified group. -// Returns true if groupID is nil (no group restriction) or account belongs to the group. +// When groupID is nil, returns true only for ungrouped accounts (no group assignments). func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { - if groupID == nil { - return true // 无分组限制 - } if account == nil { return false } + if groupID == nil { + // 无分组的 API Key 只能使用未分组的账号 + return len(account.AccountGroups) == 0 + } for _, ag := range account.AccountGroups { if ag.GroupID == *groupID { return true @@ -2110,6 +2177,15 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts [] return context.WithValue(ctx, windowCostPrefetchContextKey, costs) } +// isAccountSchedulableForQuota 检查账号是否在配额限制内 +// 适用于配置了 quota_limit 的 apikey 和 bedrock 类型账号 +func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool { + if !account.IsAPIKeyOrBedrock() { + return true + } + return !account.IsQuotaExceeded() +} + // isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 // 仅适用于 Anthropic OAuth/SetupToken 账号 // 返回 true 表示可调度,false 表示不可调度 @@ -2587,7 +2663,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } 
- if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -2641,6 +2717,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } @@ -2697,7 +2776,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && 
account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { return account, nil } } @@ -2740,6 +2819,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } @@ -2773,9 +2855,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if selected == nil { stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false) if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats)) + return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats)) } - return nil, errors.New("no available accounts") + return nil, ErrNoAvailableAccounts } // 4. 
建立粘性绑定 @@ -2815,7 +2897,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) @@ -2871,6 +2953,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } @@ -2927,7 +3012,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && 
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { return account, nil } @@ -2972,6 +3057,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } @@ -3005,9 +3093,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if selected == nil { stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true) if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats)) + return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats)) } - return nil, errors.New("no available accounts") + return nil, ErrNoAvailableAccounts } // 4. 
建立粘性绑定 @@ -3286,6 +3374,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo if account.Platform == PlatformSora { return s.isSoraModelSupportedByAccount(account, requestedModel) } + if account.IsBedrock() { + _, ok := ResolveBedrockModelID(account, requestedModel) + return ok + } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) @@ -3443,6 +3535,8 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( return "", "", errors.New("api_key not found in credentials") } return apiKey, "apikey", nil + case AccountTypeBedrock: + return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理 default: return "", "", fmt.Errorf("unsupported account type: %s", account.Type) } @@ -3886,7 +3980,34 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { - return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime) + passthroughBody := parsed.Body + passthroughModel := parsed.Model + if passthroughModel != "" { + if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel { + passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel) + logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) + passthroughModel = mappedModel + } + } + return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) + } + + if account != nil && account.IsBedrock() { + return s.forwardBedrock(ctx, c, account, parsed, startTime) + } + + // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. 
+ // Always overwrite the cache to prevent stale values from a previous retry with a different account. + if account.Platform == PlatformAnthropic && c != nil { + policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account) + if policy.blockErr != nil { + return nil, policy.blockErr + } + filterSet := policy.filterSet + if filterSet == nil { + filterSet = map[string]struct{}{} + } + c.Set(betaPolicyFilterSetKey, filterSet) } body := parsed.Body @@ -3976,7 +4097,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -4014,7 +4137,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if readErr == nil { _ = resp.Body.Close() - if s.isThinkingBlockSignatureError(respBody) { + if s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -4054,7 +4177,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // also downgrade tool_use/tool_result blocks to text. 
filteredBody := FilterThinkingBlocksForRetry(body) - retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx() if buildErr == nil { retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -4086,7 +4211,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) - retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream) + retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx2() if buildErr2 == nil { retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr2 == nil { @@ -4131,7 +4258,47 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) break } - // 不是thinking签名错误,恢复响应体 + // 不是签名错误(或整流器已关闭),继续检查 budget 约束 + errMsg := extractUpstreamErrorMessage(respBody) + if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) { + appendOpsUpstreamError(c, 
OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "budget_constraint_error", + Message: errMsg, + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + + rectifiedBody, applied := RectifyThinkingBudget(body) + if applied && time.Since(retryStart) < maxRetryElapsed { + logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens) + budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseBudgetRetryCtx() + if buildErr == nil { + budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if retryErr == nil { + resp = budgetRetryResp + break + } + if budgetRetryResp != nil && budgetRetryResp.Body != nil { + _ = budgetRetryResp.Body.Close() + } + logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry failed: %v", account.ID, retryErr) + } else { + logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry build failed: %v", account.ID, buildErr) + } + } + } + resp.Body = io.NopCloser(bytes.NewReader(respBody)) } } @@ -4223,7 +4390,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + 
ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } } return s.handleRetryExhaustedError(ctx, resp, c, account) } @@ -4253,7 +4424,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } } if resp.StatusCode >= 400 { // 可选:对部分 400 触发 failover(默认关闭以保持语义) @@ -4305,6 +4480,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 处理正常响应 + + // 触发上游接受回调(提前释放串行锁,不等流完成) + if parsed.OnUpstreamAccepted != nil { + parsed.OnUpstreamAccepted() + } + var usage *ClaudeUsage var firstTokenMs *int var clientDisconnect bool @@ -4373,7 +4554,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( var resp *http.Response retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { - upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -4482,7 +4665,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } } return s.handleRetryExhaustedError(ctx, resp, c, account) } @@ -4512,7 +4699,11 @@ func (s *GatewayService) 
forwardAnthropicAPIKeyPassthrough( return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } } if resp.StatusCode >= 400 { @@ -4565,7 +4756,7 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough( if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages" + targetURL = validatedURL + "/v1/messages?beta=true" } req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) @@ -4641,6 +4832,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( usage := &ClaudeUsage{} var firstTokenMs *int clientDisconnected := false + sawTerminalEvent := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -4703,17 +4895,20 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 flusher.Flush() } + if !sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") + } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil } if ev.err != nil { + if sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } if clientDisconnected { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after 
disconnect: %w", ev.err) } if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v", - account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err()) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) } if errors.Is(ev.err, bufio.ErrTooLong) { logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) @@ -4725,11 +4920,19 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( line := ev.line if data, ok := extractAnthropicSSEDataLine(line); ok { trimmed := strings.TrimSpace(data) + if anthropicStreamEventIsTerminal("", trimmed) { + sawTerminalEvent = true + } if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } s.parseSSEUsagePassthrough(data, usage) + } else { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") { + sawTerminalEvent = true + } } if !clientDisconnected { @@ -4751,8 +4954,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( continue } if clientDisconnected { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") } 
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) if s.rateLimitService != nil { @@ -4935,6 +5137,368 @@ func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, } } +// forwardBedrock 转发请求到 AWS Bedrock +func (s *GatewayService) forwardBedrock( + ctx context.Context, + c *gin.Context, + account *Account, + parsed *ParsedRequest, + startTime time.Time, +) (*ForwardResult, error) { + reqModel := parsed.Model + reqStream := parsed.Stream + body := parsed.Body + + region := bedrockRuntimeRegion(account) + mappedModel, ok := ResolveBedrockModelID(account, reqModel) + if !ok { + return nil, fmt.Errorf("unsupported bedrock model: %s", reqModel) + } + if mappedModel != reqModel { + logger.LegacyPrintf("service.gateway", "[Bedrock] Model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) + } + + betaHeader := "" + if c != nil && c.Request != nil { + betaHeader = c.GetHeader("anthropic-beta") + } + + // 准备请求体(注入 anthropic_version/anthropic_beta,移除 Bedrock 不支持的字段,清理 cache_control) + betaTokens, err := s.resolveBedrockBetaTokensForRequest(ctx, account, betaHeader, body, mappedModel) + if err != nil { + return nil, err + } + + bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens) + if err != nil { + return nil, fmt.Errorf("prepare bedrock request body: %w", err) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + logger.LegacyPrintf("service.gateway", "[Bedrock] 命中 Bedrock 分支: account=%d name=%s model=%s->%s stream=%v", + account.ID, account.Name, reqModel, mappedModel, reqStream) + + // 根据账号类型选择认证方式 + var signer *BedrockSigner + var bedrockAPIKey string + if account.IsBedrockAPIKey() { + bedrockAPIKey = account.GetCredential("api_key") + if bedrockAPIKey == "" { + return nil, fmt.Errorf("api_key not found in bedrock 
credentials") + } + } else { + signer, err = NewBedrockSignerFromAccount(account) + if err != nil { + return nil, fmt.Errorf("create bedrock signer: %w", err) + } + } + + // 执行上游请求(含重试) + resp, err := s.executeBedrockUpstream(ctx, c, account, bedrockBody, mappedModel, region, reqStream, signer, bedrockAPIKey, proxyURL) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + // 将 Bedrock 的 x-amzn-requestid 映射到 x-request-id, + // 使通用错误处理函数(handleErrorResponse、handleRetryExhaustedError)能正确提取 AWS request ID。 + if awsReqID := resp.Header.Get("x-amzn-requestid"); awsReqID != "" && resp.Header.Get("x-request-id") == "" { + resp.Header.Set("x-request-id", awsReqID) + } + + // 错误/failover 处理 + if resp.StatusCode >= 400 { + return s.handleBedrockUpstreamErrors(ctx, resp, c, account) + } + + // 响应处理 + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + if reqStream { + streamResult, err := s.handleBedrockStreamingResponse(ctx, resp, c, account, startTime, reqModel) + if err != nil { + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect + } else { + usage, err = s.handleBedrockNonStreamingResponse(ctx, resp, c, account) + if err != nil { + return nil, err + } + } + if usage == nil { + usage = &ClaudeUsage{} + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-amzn-requestid"), + Usage: *usage, + Model: reqModel, + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + }, nil +} + +// executeBedrockUpstream 执行 Bedrock 上游请求(含重试逻辑) +func (s *GatewayService) executeBedrockUpstream( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + modelID string, + region string, + stream bool, + signer *BedrockSigner, + apiKey string, + proxyURL string, +) (*http.Response, error) { + var resp *http.Response + var err error + retryStart := 
time.Now() + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + var upstreamReq *http.Request + if account.IsBedrockAPIKey() { + upstreamReq, err = s.buildUpstreamRequestBedrockAPIKey(ctx, body, modelID, region, stream, apiKey) + } else { + upstreamReq, err = s.buildUpstreamRequestBedrock(ctx, body, modelID, region, stream, signer) + } + if err != nil { + return nil, err + } + + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, false) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if attempt < maxRetryAttempts { + elapsed := time.Since(retryStart) + if elapsed >= maxRetryElapsed { + break + } + + delay := retryBackoffDelay(attempt) + remaining := maxRetryElapsed - elapsed + if delay > remaining { + delay = remaining + } + if delay <= 0 { + break + } + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "retry", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } 
+ return "" + }(), + }) + logger.LegacyPrintf("service.gateway", "[Bedrock] account %d: upstream error %d, retry %d/%d after %v", + account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + break + } + + break + } + if resp == nil || resp.Body == nil { + return nil, errors.New("upstream request failed: empty response") + } + return resp, nil +} + +// handleBedrockUpstreamErrors 处理 Bedrock 上游 4xx/5xx 错误(failover + 错误响应) +func (s *GatewayService) handleBedrockUpstreamErrors( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ForwardResult, error) { + // retry exhausted + failover + if s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Bedrock] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d Body=%s", + account.ID, account.Name, resp.StatusCode, truncateString(string(respBody), 1000)) + + s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "retry_exhausted_failover", + Message: extractUpstreamErrorMessage(respBody), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + + // non-retryable failover + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = 
io.NopCloser(bytes.NewReader(respBody)) + + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "failover", + Message: extractUpstreamErrorMessage(respBody), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + + // other errors + return s.handleErrorResponse(ctx, resp, c, account) +} + +// buildUpstreamRequestBedrock 构建 Bedrock 上游请求 +func (s *GatewayService) buildUpstreamRequestBedrock( + ctx context.Context, + body []byte, + modelID string, + region string, + stream bool, + signer *BedrockSigner, +) (*http.Request, error) { + targetURL := BuildBedrockURL(region, modelID, stream) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + // SigV4 签名 + if err := signer.SignRequest(ctx, req, body); err != nil { + return nil, fmt.Errorf("sign bedrock request: %w", err) + } + + return req, nil +} + +// buildUpstreamRequestBedrockAPIKey 构建 Bedrock API Key (Bearer Token) 上游请求 +func (s *GatewayService) buildUpstreamRequestBedrockAPIKey( + ctx context.Context, + body []byte, + modelID string, + region string, + stream bool, + apiKey string, +) (*http.Request, error) { + targetURL := BuildBedrockURL(region, modelID, stream) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + return req, nil +} + +// 
handleBedrockNonStreamingResponse 处理 Bedrock 非流式响应 +// Bedrock InvokeModel 非流式响应的 body 格式与 Claude API 兼容 +func (s *GatewayService) handleBedrockNonStreamingResponse( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ClaudeUsage, error) { + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + // 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式 + // 并移除该字段避免透传给客户端 + body = transformBedrockInvocationMetrics(body) + + usage := parseClaudeUsageFromResponseBody(body) + + c.Header("Content-Type", "application/json") + if v := resp.Header.Get("x-amzn-requestid"); v != "" { + c.Header("x-request-id", v) + } + c.Data(resp.StatusCode, "application/json", body) + return usage, nil +} + func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL @@ -4945,7 +5509,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages" + targetURL = validatedURL + "/v1/messages?beta=true" } } @@ -5014,6 +5578,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex applyClaudeOAuthHeaderDefaults(req, reqStream) } + // Build effective drop set: merge static defaults with dynamic beta policy filter rules + policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account) + effectiveDropSet := 
mergeDropSets(policyFilterSet) + effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode) + // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) if tokenType == "oauth" { if mimicClaudeCode { @@ -5027,17 +5596,22 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // messages requests typically use only oauth + interleaved-thinking. // Also drop claude-code beta if a downstream client added it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, droppedBetasWithClaudeCodeSet)) + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := req.Header.Get("anthropic-beta") - req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), defaultDroppedBetasSet)) + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet)) } - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { - // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) - if requestNeedsBetaFeatures(body) { - if beta := defaultAPIKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) + } else { + // API-key accounts: apply beta policy filter to strip controlled tokens + if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" { + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet)) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey { + // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultAPIKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } } } } @@ -5215,6 +5789,107 @@ func 
stripBetaTokensWithSet(header string, drop map[string]struct{}) string { return strings.Join(out, ",") } +// BetaBlockedError indicates a request was blocked by a beta policy rule. +type BetaBlockedError struct { + Message string +} + +func (e *BetaBlockedError) Error() string { return e.Message } + +// betaPolicyResult holds the evaluated result of beta policy rules for a single request. +type betaPolicyResult struct { + blockErr *BetaBlockedError // non-nil if a block rule matched + filterSet map[string]struct{} // tokens to filter (may be nil) +} + +// evaluateBetaPolicy loads settings once and evaluates all rules against the given request. +func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult { + if s.settingService == nil { + return betaPolicyResult{} + } + settings, err := s.settingService.GetBetaPolicySettings(ctx) + if err != nil || settings == nil { + return betaPolicyResult{} + } + isOAuth := account.IsOAuth() + isBedrock := account.IsBedrock() + var result betaPolicyResult + for _, rule := range settings.Rules { + if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { + continue + } + switch rule.Action { + case BetaPolicyActionBlock: + if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) { + msg := rule.ErrorMessage + if msg == "" { + msg = "beta feature " + rule.BetaToken + " is not allowed" + } + result.blockErr = &BetaBlockedError{Message: msg} + } + case BetaPolicyActionFilter: + if result.filterSet == nil { + result.filterSet = make(map[string]struct{}) + } + result.filterSet[rule.BetaToken] = struct{}{} + } + } + return result +} + +// mergeDropSets merges the static defaultDroppedBetasSet with dynamic policy filter tokens. +// Returns defaultDroppedBetasSet directly when policySet is empty (zero allocation). 
+func mergeDropSets(policySet map[string]struct{}, extra ...string) map[string]struct{} { + if len(policySet) == 0 && len(extra) == 0 { + return defaultDroppedBetasSet + } + m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(policySet)+len(extra)) + for t := range defaultDroppedBetasSet { + m[t] = struct{}{} + } + for t := range policySet { + m[t] = struct{}{} + } + for _, t := range extra { + m[t] = struct{}{} + } + return m +} + +// betaPolicyFilterSetKey is the gin.Context key for caching the policy filter set within a request. +const betaPolicyFilterSetKey = "betaPolicyFilterSet" + +// getBetaPolicyFilterSet returns the beta policy filter set, using the gin context cache if available. +// In the /v1/messages path, Forward() evaluates the policy first and caches the result; +// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this +// evaluates on demand (one DB call). +func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} { + if c != nil { + if v, ok := c.Get(betaPolicyFilterSetKey); ok { + if fs, ok := v.(map[string]struct{}); ok { + return fs + } + } + } + return s.evaluateBetaPolicy(ctx, "", account).filterSet +} + +// betaPolicyScopeMatches checks whether a rule's scope matches the current account type. +func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool { + switch scope { + case BetaPolicyScopeAll: + return true + case BetaPolicyScopeOAuth: + return isOAuth + case BetaPolicyScopeAPIKey: + return !isOAuth && !isBedrock + case BetaPolicyScopeBedrock: + return isBedrock + default: + return true // unknown scope → match all (fail-open) + } +} + // droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. 
func droppedBetaSet(extra ...string) map[string]struct{} { m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) @@ -5227,6 +5902,90 @@ func droppedBetaSet(extra ...string) map[string]struct{} { return m } +// containsBetaToken checks if a comma-separated header value contains the given token. +func containsBetaToken(header, token string) bool { + if header == "" || token == "" { + return false + } + for _, p := range strings.Split(header, ",") { + if strings.TrimSpace(p) == token { + return true + } + } + return false +} + +func filterBetaTokens(tokens []string, filterSet map[string]struct{}) []string { + if len(tokens) == 0 || len(filterSet) == 0 { + return tokens + } + kept := make([]string, 0, len(tokens)) + for _, token := range tokens { + if _, filtered := filterSet[token]; !filtered { + kept = append(kept, token) + } + } + return kept +} + +func (s *GatewayService) resolveBedrockBetaTokensForRequest( + ctx context.Context, + account *Account, + betaHeader string, + body []byte, + modelID string, +) ([]string, error) { + // 1. 对原始 header 中的 beta token 做 block 检查(快速失败) + policy := s.evaluateBetaPolicy(ctx, betaHeader, account) + if policy.blockErr != nil { + return nil, policy.blockErr + } + + // 2. 解析 header + body 自动注入 + Bedrock 转换/过滤 + betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID) + + // 3. 
对最终 token 列表再做 block 检查,捕获通过 body 自动注入绕过 header block 的情况。 + // 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, + // 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 → + // 如果不做此检查,block 规则会被绕过。 + if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil { + return nil, blockErr + } + + return filterBetaTokens(betaTokens, policy.filterSet), nil +} + +// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。 +// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。 +func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError { + if s.settingService == nil || len(tokens) == 0 { + return nil + } + settings, err := s.settingService.GetBetaPolicySettings(ctx) + if err != nil || settings == nil { + return nil + } + isOAuth := account.IsOAuth() + isBedrock := account.IsBedrock() + tokenSet := buildBetaTokenSet(tokens) + for _, rule := range settings.Rules { + if rule.Action != BetaPolicyActionBlock { + continue + } + if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { + continue + } + if _, present := tokenSet[rule.BetaToken]; present { + msg := rule.ErrorMessage + if msg == "" { + msg = "beta feature " + rule.BetaToken + " is not allowed" + } + return &BetaBlockedError{Message: msg} + } + } + return nil +} + func buildBetaTokenSet(tokens []string) map[string]struct{} { m := make(map[string]struct{}, len(tokens)) for _, t := range tokens { @@ -5238,10 +5997,7 @@ func buildBetaTokenSet(tokens []string) map[string]struct{} { return m } -var ( - defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas) - droppedBetasWithClaudeCodeSet = droppedBetaSet(claude.BetaClaudeCode) -) +var defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas) // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. 
// This mirrors opencode-anthropic-auth behavior: do not trust downstream @@ -5368,10 +6124,38 @@ func extractUpstreamErrorMessage(body []byte) string { return m } + // ChatGPT 内部 API 风格:{"detail":"..."} + if d := gjson.GetBytes(body, "detail").String(); strings.TrimSpace(d) != "" { + return d + } + // 兜底:尝试顶层 message return gjson.GetBytes(body, "message").String() } +func extractUpstreamErrorCode(body []byte) string { + if code := strings.TrimSpace(gjson.GetBytes(body, "error.code").String()); code != "" { + return code + } + + inner := strings.TrimSpace(gjson.GetBytes(body, "error.message").String()) + if !strings.HasPrefix(inner, "{") { + return "" + } + + if code := strings.TrimSpace(gjson.Get(inner, "error.code").String()); code != "" { + return code + } + + if lastBrace := strings.LastIndex(inner, "}"); lastBrace >= 0 { + if code := strings.TrimSpace(gjson.Get(inner[:lastBrace+1], "error.code").String()); code != "" { + return code + } + } + + return "" +} + func isCountTokensUnsupported404(statusCode int, body []byte) bool { if statusCode != http.StatusNotFound { return false @@ -5742,6 +6526,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http intervalCh = intervalTicker.C } + // 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开 + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + lastDataAt := time.Now() + // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) errorEventSent := false sendErrorEvent := func(reason string) { @@ -5755,6 +6555,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := 
originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage + sawTerminalEvent := false pendingEventLines := make([]string, 0, 4) @@ -5785,6 +6586,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if dataLine == "[DONE]" { + sawTerminalEvent = true block := "" if eventName != "" { block = "event: " + eventName + "\n" @@ -5851,6 +6653,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } usagePatch := s.extractSSEUsagePatch(event) + if anthropicStreamEventIsTerminal(eventName, dataLine) { + sawTerminalEvent = true + } if !eventChanged { block := "" if eventName != "" { @@ -5884,18 +6689,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http case ev, ok := <-events: if !ok { // 上游完成,返回结果 + if !sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") + } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil } if ev.err != nil { + if sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } // 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取) if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage") - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) } // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage if clientDisconnected { - logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err) - return &streamingResult{usage: 
usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) } // 客户端未断开,正常的错误处理 if errors.Is(ev.err, bufio.ErrTooLong) { @@ -5931,6 +6740,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http break } flusher.Flush() + lastDataAt = time.Now() } if data != "" { if firstTokenMs == nil && data != "[DONE]" { @@ -5953,9 +6763,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http continue } if clientDisconnected { - // 客户端已断开,上游也超时了,返回已收集的 usage - logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage") - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") } logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 @@ -5964,6 +6772,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } sendErrorEvent("stream_timeout") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // SSE ping 事件:Anthropic 原生格式,客户端会正确处理, + // 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开 + if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing") + continue + } + flusher.Flush() } } @@ -6283,81 +7107,318 @@ func (s *GatewayService) 
replaceModelInResponseBody(body []byte, fromModel, toMo } func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { - if s == nil || userID <= 0 || groupID <= 0 { + if s == nil { return groupDefaultMultiplier } - - key := fmt.Sprintf("%d:%d", userID, groupID) - if s.userGroupRateCache != nil { - if cached, ok := s.userGroupRateCache.Get(key); ok { - if multiplier, castOK := cached.(float64); castOK { - userGroupRateCacheHitTotal.Add(1) - return multiplier - } - } + resolver := s.userGroupRateResolver + if resolver == nil { + resolver = newUserGroupRateResolver( + s.userGroupRateRepo, + s.userGroupRateCache, + resolveUserGroupRateCacheTTL(s.cfg), + &s.userGroupRateSF, + "service.gateway", + ) } - if s.userGroupRateRepo == nil { - return groupDefaultMultiplier - } - userGroupRateCacheMissTotal.Add(1) - - value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) { - if s.userGroupRateCache != nil { - if cached, ok := s.userGroupRateCache.Get(key); ok { - if multiplier, castOK := cached.(float64); castOK { - userGroupRateCacheHitTotal.Add(1) - return multiplier, nil - } - } - } - - userGroupRateCacheLoadTotal.Add(1) - userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID) - if repoErr != nil { - return nil, repoErr - } - multiplier := groupDefaultMultiplier - if userRate != nil { - multiplier = *userRate - } - if s.userGroupRateCache != nil { - s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg)) - } - return multiplier, nil - }) - if shared { - userGroupRateCacheSFSharedTotal.Add(1) - } - if err != nil { - userGroupRateCacheFallbackTotal.Add(1) - logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err) - return groupDefaultMultiplier - } - - multiplier, ok := value.(float64) - if !ok { - userGroupRateCacheFallbackTotal.Add(1) - return 
groupDefaultMultiplier - } - return multiplier + return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier) } // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { - Result *ForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + InboundEndpoint string // 入站端点(客户端请求路径) + UpstreamEndpoint string // 上游端点(标准化后的上游路径) + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 } -// APIKeyQuotaUpdater defines the interface for updating API Key quota +// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage type APIKeyQuotaUpdater interface { UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error + UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error +} + +type apiKeyAuthCacheInvalidator interface { + InvalidateAuthCacheByKey(ctx context.Context, key string) +} + +type usageLogBestEffortWriter interface { + CreateBestEffort(ctx context.Context, log *UsageLog) error +} + +// postUsageBillingParams 统一扣费所需的参数 +type postUsageBillingParams struct { + Cost *CostBreakdown + User *User + APIKey *APIKey + Account *Account + Subscription *UserSubscription + RequestPayloadHash string + IsSubscriptionBill bool + AccountRateMultiplier float64 + APIKeyService APIKeyQuotaUpdater +} + +// postUsageBilling 统一处理使用量记录后的扣费逻辑: +// - 订阅/余额扣费 +// - API Key 配额更新 +// - API Key 限速用量更新 +// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率) 
+func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { + billingCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + cost := p.Cost + + // 1. 订阅 / 余额扣费 + if p.IsSubscriptionBill { + if cost.TotalCost > 0 { + if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { + slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) + } + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost) + } + } else { + if cost.ActualCost > 0 { + if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil { + slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) + } + deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost) + } + } + + // 2. API Key 配额 + if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) + } + } + + // 3. API Key 限速用量 + if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) + } + } + + // 4. 
账号配额用量(账号口径:TotalCost × 账号计费倍率) + if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { + accountCost := cost.TotalCost * p.AccountRateMultiplier + if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { + slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) + } + } + + finalizePostUsageBilling(p, deps) +} + +func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { + if ctx != nil { + if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { + return "client:" + strings.TrimSpace(clientRequestID) + } + if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { + return "local:" + strings.TrimSpace(requestID) + } + } + if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" { + return requestID + } + return "generated:" + generateRequestID() +} + +func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string { + if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" { + return payloadHash + } + if ctx != nil { + if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { + return "client:" + strings.TrimSpace(clientRequestID) + } + if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { + return "local:" + strings.TrimSpace(requestID) + } + } + return "" +} + +func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand { + if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil { + return nil + } + + cmd := &UsageBillingCommand{ + RequestID: requestID, + APIKeyID: p.APIKey.ID, + UserID: p.User.ID, + AccountID: p.Account.ID, + AccountType: p.Account.Type, + 
RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash), + } + if usageLog != nil { + cmd.Model = usageLog.Model + cmd.BillingType = usageLog.BillingType + cmd.InputTokens = usageLog.InputTokens + cmd.OutputTokens = usageLog.OutputTokens + cmd.CacheCreationTokens = usageLog.CacheCreationTokens + cmd.CacheReadTokens = usageLog.CacheReadTokens + cmd.ImageCount = usageLog.ImageCount + if usageLog.MediaType != nil { + cmd.MediaType = *usageLog.MediaType + } + if usageLog.ServiceTier != nil { + cmd.ServiceTier = *usageLog.ServiceTier + } + if usageLog.ReasoningEffort != nil { + cmd.ReasoningEffort = *usageLog.ReasoningEffort + } + if usageLog.SubscriptionID != nil { + cmd.SubscriptionID = usageLog.SubscriptionID + } + } + + if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { + cmd.SubscriptionID = &p.Subscription.ID + cmd.SubscriptionCost = p.Cost.TotalCost + } else if p.Cost.ActualCost > 0 { + cmd.BalanceCost = p.Cost.ActualCost + } + + if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + cmd.APIKeyQuotaCost = p.Cost.ActualCost + } + if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + cmd.APIKeyRateLimitCost = p.Cost.ActualCost + } + if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { + cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier + } + + cmd.Normalize() + return cmd +} + +func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) { + if p == nil || deps == nil { + return false, nil + } + + cmd := buildUsageBillingCommand(requestID, usageLog, p) + if cmd == nil || cmd.RequestID == "" || repo == nil { + postUsageBilling(ctx, p, deps) + return true, nil + } + + billingCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + result, err := repo.Apply(billingCtx, cmd) + if err != nil { + return false, err + } + + if 
result == nil || !result.Applied { + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) + return false, nil + } + + if result.APIKeyQuotaExhausted { + if invalidator, ok := p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" { + invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key) + } + } + + finalizePostUsageBilling(p, deps) + return true, nil +} + +func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { + if p == nil || p.Cost == nil || deps == nil { + return + } + + if p.IsSubscriptionBill { + if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost) + } + } else if p.Cost.ActualCost > 0 && p.User != nil { + deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) + } + + if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() { + deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost) + } + + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) +} + +func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) { + base := context.Background() + if ctx != nil { + base = context.WithoutCancel(ctx) + } + return context.WithTimeout(base, postUsageBillingTimeout) +} + +func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if !stream { + return ctx, func() {} + } + if ctx == nil { + return context.Background(), func() {} + } + return context.WithoutCancel(ctx), func() {} +} + +// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) +type billingDeps struct { + accountRepo AccountRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + billingCacheService *BillingCacheService + deferredService *DeferredService +} + +func (s *GatewayService) billingDeps() *billingDeps { + return &billingDeps{ + 
accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + } +} + +func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) { + if repo == nil || usageLog == nil { + return + } + usageCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + if writer, ok := repo.(usageLogBestEffortWriter); ok { + if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + if IsUsageLogCreateDropped(err) { + return + } + if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil { + logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr) + } + } + return + } + + if _, err := repo.Create(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + } } // RecordUsage 记录使用量并扣费(或更新订阅用量) @@ -6461,12 +7522,16 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu mediaType = &result.MediaType } accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: result.Model, + ReasoningEffort: result.ReasoningEffort, + InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), + UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, @@ -6510,49 +7575,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, usageLog) - if err != nil { - logger.LegacyPrintf("service.gateway", 
"Create usage log failed: %v", err) - } - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps(), s.usageBillingRepo) + return err + }() - // 根据计费类型执行扣费 - if isSubscriptionBilling { - // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) - if shouldBill && cost.TotalCost > 0 { - if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) - } - // 异步更新订阅缓存 - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } - } else { - // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) - if shouldBill && cost.ActualCost > 0 { - if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) - } - // 异步更新余额缓存 - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - } + if billingErr != nil { + return billingErr } - - // 更新 API Key 配额(如果设置了配额限制) - if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Update 
API key quota failed: %v", err) - } - } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") return nil } @@ -6563,13 +7611,16 @@ type RecordUsageLongContextInput struct { APIKey *APIKey User *User Account *Account - Subscription *UserSubscription // 可选:订阅信息 - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - LongContextThreshold int // 长上下文阈值(如 200000) - LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService *APIKeyService // API Key 配额服务(可选) + Subscription *UserSubscription // 可选:订阅信息 + InboundEndpoint string // 入站端点(客户端请求路径) + UpstreamEndpoint string // 上游端点(标准化后的上游路径) + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + LongContextThreshold int // 长上下文阈值(如 200000) + LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) } // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) @@ -6652,12 +7703,16 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * imageSize = &result.ImageSize } accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: result.Model, + ReasoningEffort: result.ReasoningEffort, + InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), + UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, @@ 
-6700,48 +7755,32 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, usageLog) - if err != nil { - logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) - } - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps(), s.usageBillingRepo) + return err + }() - // 根据计费类型执行扣费 - if isSubscriptionBilling { - // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) - if shouldBill && cost.TotalCost > 0 { - if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) - } - // 异步更新订阅缓存 - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } - } else { - // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) - if shouldBill && cost.ActualCost > 0 { - if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) - } - // 异步更新余额缓存 - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - // API Key 独立配额扣费 - if input.APIKeyService != nil && 
apiKey.Quota > 0 { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err) - } - } - } + if billingErr != nil { + return billingErr } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") return nil } @@ -6755,7 +7794,20 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { - return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body) + passthroughBody := parsed.Body + if reqModel := parsed.Model; reqModel != "" { + if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel { + passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel) + logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) + } + } + return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody) + } + + // Bedrock 不支持 count_tokens 端点 + if account != nil && account.IsBedrock() { + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for Bedrock") + return nil } body := parsed.Body @@ -6845,7 +7897,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks) - if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) { + if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) filteredBody := 
FilterThinkingBlocksForRetry(body) @@ -7046,7 +8098,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages/count_tokens" + targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" } req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) @@ -7093,7 +8145,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages/count_tokens" + targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" } } @@ -7157,6 +8209,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con applyClaudeOAuthHeaderDefaults(req, false) } + // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules + ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account)) + // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { if mimicClaudeCode { @@ -7164,8 +8219,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con incomingBeta := req.Header.Get("anthropic-beta") requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} - drop := droppedBetaSet() - req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet)) } else { clientBetaHeader := req.Header.Get("anthropic-beta") if clientBetaHeader == "" { @@ -7175,14 +8229,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if !strings.Contains(beta, claude.BetaTokenCounting) { beta = beta + "," + claude.BetaTokenCounting } - req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, defaultDroppedBetasSet)) + 
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet)) } } - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { - // API-key:与 messages 同步的按需 beta 注入(默认关闭) - if requestNeedsBetaFeatures(body) { - if beta := defaultAPIKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) + } else { + // API-key accounts: apply beta policy filter to strip controlled tokens + if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" { + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet)) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey { + // API-key:与 messages 同步的按需 beta 注入(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultAPIKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } } } } diff --git a/backend/internal/service/gateway_service_bedrock_beta_test.go b/backend/internal/service/gateway_service_bedrock_beta_test.go new file mode 100644 index 00000000..8920ee08 --- /dev/null +++ b/backend/internal/service/gateway_service_bedrock_beta_test.go @@ -0,0 +1,267 @@ +package service + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +type betaPolicySettingRepoStub struct { + values map[string]string +} + +func (s *betaPolicySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *betaPolicySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", ErrSettingNotFound +} + +func (s *betaPolicySettingRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *betaPolicySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s 
*betaPolicySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *betaPolicySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *betaPolicySettingRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestResolveBedrockBetaTokensForRequest_BlocksOnOriginalAnthropicToken(t *testing.T) { + settings := &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "advanced-tool-use-2025-11-20", + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "advanced tool use is blocked", + }, + }, + } + raw, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal settings: %v", err) + } + + svc := &GatewayService{ + settingService: NewSettingService( + &betaPolicySettingRepoStub{values: map[string]string{ + SettingKeyBetaPolicySettings: string(raw), + }}, + &config.Config{}, + ), + } + account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock} + + _, err = svc.resolveBedrockBetaTokensForRequest( + context.Background(), + account, + "advanced-tool-use-2025-11-20", + []byte(`{"messages":[{"role":"user","content":"hi"}]}`), + "us.anthropic.claude-opus-4-6-v1", + ) + if err == nil { + t.Fatal("expected raw advanced-tool-use token to be blocked before Bedrock transform") + } + if err.Error() != "advanced tool use is blocked" { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform(t *testing.T) { + settings := &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "tool-search-tool-2025-10-19", + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + }, + }, + } + raw, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal settings: %v", err) + } + + svc := &GatewayService{ + settingService: NewSettingService( + 
&betaPolicySettingRepoStub{values: map[string]string{ + SettingKeyBetaPolicySettings: string(raw), + }}, + &config.Config{}, + ), + } + account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock} + + betaTokens, err := svc.resolveBedrockBetaTokensForRequest( + context.Background(), + account, + "advanced-tool-use-2025-11-20", + []byte(`{"messages":[{"role":"user","content":"hi"}]}`), + "us.anthropic.claude-opus-4-6-v1", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, token := range betaTokens { + if token == "tool-search-tool-2025-10-19" { + t.Fatalf("expected transformed Bedrock token to be filtered") + } + } +} + +// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking 验证: +// 管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, +// 但请求体包含 thinking 字段 → 自动注入后应被 block。 +func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *testing.T) { + settings := &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "interleaved-thinking-2025-05-14", + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "thinking is blocked", + }, + }, + } + raw, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal settings: %v", err) + } + + svc := &GatewayService{ + settingService: NewSettingService( + &betaPolicySettingRepoStub{values: map[string]string{ + SettingKeyBetaPolicySettings: string(raw), + }}, + &config.Config{}, + ), + } + account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock} + + // header 中不带 beta token,但 body 中有 thinking 字段 + _, err = svc.resolveBedrockBetaTokensForRequest( + context.Background(), + account, + "", // 空 header + []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`), + "us.anthropic.claude-opus-4-6-v1", + ) + if err == nil { + t.Fatal("expected body-injected interleaved-thinking to be blocked") + } + if err.Error() != "thinking is blocked" { + 
t.Fatalf("unexpected error: %v", err) + } +} + +// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch 验证: +// 管理员 block 了 tool-search-tool,客户端不在 header 中带 beta token, +// 但请求体包含 tool search 工具 → 自动注入后应被 block。 +func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch(t *testing.T) { + settings := &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "tool-search-tool-2025-10-19", + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "tool search is blocked", + }, + }, + } + raw, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal settings: %v", err) + } + + svc := &GatewayService{ + settingService: NewSettingService( + &betaPolicySettingRepoStub{values: map[string]string{ + SettingKeyBetaPolicySettings: string(raw), + }}, + &config.Config{}, + ), + } + account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock} + + // header 中不带 beta token,但 body 中有 tool_search_tool 工具 + _, err = svc.resolveBedrockBetaTokensForRequest( + context.Background(), + account, + "", + []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`), + "us.anthropic.claude-sonnet-4-6", + ) + if err == nil { + t.Fatal("expected body-injected tool-search-tool to be blocked") + } + if err.Error() != "tool search is blocked" { + t.Fatalf("unexpected error: %v", err) + } +} + +// TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches 验证: +// body 自动注入的 token 如果没有对应的 block 规则,应正常通过。 +func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *testing.T) { + settings := &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "computer-use-2025-11-24", + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "computer use is blocked", + }, + }, + } + raw, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal settings: %v", err) + } + + svc := &GatewayService{ + 
settingService: NewSettingService( + &betaPolicySettingRepoStub{values: map[string]string{ + SettingKeyBetaPolicySettings: string(raw), + }}, + &config.Config{}, + ), + } + account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock} + + // body 中有 thinking(会注入 interleaved-thinking),但 block 规则只针对 computer-use + tokens, err := svc.resolveBedrockBetaTokensForRequest( + context.Background(), + account, + "", + []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`), + "us.anthropic.claude-opus-4-6-v1", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + found := false + for _, token := range tokens { + if token == "interleaved-thinking-2025-05-14" { + found = true + } + } + if !found { + t.Fatal("expected interleaved-thinking token to be present") + } +} diff --git a/backend/internal/service/gateway_service_bedrock_model_support_test.go b/backend/internal/service/gateway_service_bedrock_model_support_test.go new file mode 100644 index 00000000..aa8d4756 --- /dev/null +++ b/backend/internal/service/gateway_service_bedrock_model_support_test.go @@ -0,0 +1,48 @@ +package service + +import "testing" + +func TestGatewayServiceIsModelSupportedByAccount_BedrockDefaultMappingRestrictsModels(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "us-east-1", + }, + } + + if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-5") { + t.Fatalf("expected default Bedrock alias to be supported") + } + + if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") { + t.Fatalf("expected unsupported alias to be rejected for Bedrock account") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_BedrockCustomMappingStillActsAsAllowlist(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + 
Credentials: map[string]any{ + "aws_region": "eu-west-1", + "model_mapping": map[string]any{ + "claude-sonnet-*": "claude-sonnet-4-6", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-6") { + t.Fatalf("expected matched custom mapping to be supported") + } + + if !svc.isModelSupportedByAccount(account, "claude-opus-4-6") { + t.Fatalf("expected default Bedrock alias fallback to remain supported") + } + + if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") { + t.Fatalf("expected unsupported model to still be rejected") + } +} diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go index cd690cbd..b1584827 100644 --- a/backend/internal/service/gateway_streaming_test.go +++ b/backend/internal/service/gateway_streaming_test.go @@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) { result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) _ = pr.Close() - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") require.NotNil(t, result) } diff --git a/backend/internal/service/gemini_error_policy_test.go b/backend/internal/service/gemini_error_policy_test.go index 2ce8793a..4bd1ced7 100644 --- a/backend/internal/service/gemini_error_policy_test.go +++ b/backend/internal/service/gemini_error_policy_test.go @@ -122,6 +122,28 @@ func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) { body: []byte(`overloaded service`), expected: ErrorPolicyTempUnscheduled, }, + { + name: "gemini_apikey_temp_unschedulable_401_second_hit_returns_none", + account: &Account{ + ID: 105, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + 
"error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 401, + body: []byte(`unauthorized`), + expected: ErrorPolicyNone, + }, { name: "gemini_custom_codes_override_temp_unschedulable", account: &Account{ diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 1c38b6c2..e65c838d 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -431,7 +431,10 @@ func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Co if groupID != nil { return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms) } - return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) + } + return s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, queryPlatforms) } func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) { @@ -3232,7 +3235,7 @@ func cleanToolSchema(schema any) any { for key, value := range v { // 跳过不支持的字段 if key == "$schema" || key == "$id" || key == "$ref" || - key == "additionalProperties" || key == "minLength" || + key == "additionalProperties" || key == "patternProperties" || key == "minLength" || key == "maxLength" || key == "minItems" || key == "maxItems" { continue } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 86bc9476..b0b804eb 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -138,6 +138,12 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont } return m.ListSchedulableByPlatforms(ctx, platforms) } +func (m 
*mockAccountRepoForGemini) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return m.ListSchedulableByPlatform(ctx, platform) +} +func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + return m.ListSchedulableByPlatforms(ctx, platforms) +} func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } @@ -170,6 +176,14 @@ func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64, return 0, nil } +func (m *mockAccountRepoForGemini) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func (m *mockAccountRepoForGemini) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + // Verify interface implementation var _ AccountRepository = (*mockAccountRepoForGemini)(nil) diff --git a/backend/internal/service/gemini_native_signature_cleaner_test.go b/backend/internal/service/gemini_native_signature_cleaner_test.go new file mode 100644 index 00000000..2e184919 --- /dev/null +++ b/backend/internal/service/gemini_native_signature_cleaner_test.go @@ -0,0 +1,75 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/stretchr/testify/require" +) + +func TestCleanGeminiNativeThoughtSignatures_ReplacesNestedThoughtSignatures(t *testing.T) { + input := []byte(`{ + "contents": [ + { + "role": "user", + "parts": [{"text": "hello"}] + }, + { + "role": "model", + "parts": [ + {"text": "thinking", "thought": true, "thoughtSignature": "sig_1"}, + {"functionCall": {"name": "toolA", "args": {"k": "v"}}, "thoughtSignature": "sig_2"} + ] + } + ], + "cachedContent": { + "parts": [{"text": "cached", "thoughtSignature": "sig_3"}] + }, + "signature": "keep_me" + }`) + + cleaned := CleanGeminiNativeThoughtSignatures(input) + + var got map[string]any + 
require.NoError(t, json.Unmarshal(cleaned, &got)) + + require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_1"`) + require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_2"`) + require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_3"`) + require.Contains(t, string(cleaned), `"thoughtSignature":"`+antigravity.DummyThoughtSignature+`"`) + require.Contains(t, string(cleaned), `"signature":"keep_me"`) +} + +func TestCleanGeminiNativeThoughtSignatures_InvalidJSONReturnsOriginal(t *testing.T) { + input := []byte(`{"contents":[invalid-json]}`) + + cleaned := CleanGeminiNativeThoughtSignatures(input) + + require.Equal(t, input, cleaned) +} + +func TestReplaceThoughtSignaturesRecursive_OnlyReplacesTargetField(t *testing.T) { + input := map[string]any{ + "thoughtSignature": "sig_root", + "signature": "keep_signature", + "nested": []any{ + map[string]any{ + "thoughtSignature": "sig_nested", + "signature": "keep_nested_signature", + }, + }, + } + + got, ok := replaceThoughtSignaturesRecursive(input).(map[string]any) + require.True(t, ok) + require.Equal(t, antigravity.DummyThoughtSignature, got["thoughtSignature"]) + require.Equal(t, "keep_signature", got["signature"]) + + nested, ok := got["nested"].([]any) + require.True(t, ok) + nestedMap, ok := nested[0].(map[string]any) + require.True(t, ok) + require.Equal(t, antigravity.DummyThoughtSignature, nestedMap["thoughtSignature"]) + require.Equal(t, "keep_nested_signature", nestedMap["signature"]) +} diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index e866bdc3..08a74a37 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -1045,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR ValidateResolvedIP: true, }) if err != nil { - client = &http.Client{Timeout: 30 * time.Second} + return "", fmt.Errorf("create http client failed: 
%w", err) } resp, err := client.Do(req) diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 313b048f..1dab67c4 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -15,10 +15,14 @@ const ( geminiTokenCacheSkew = 5 * time.Minute ) +// GeminiTokenProvider manages access_token for Gemini OAuth accounts. type GeminiTokenProvider struct { accountRepo AccountRepository tokenCache GeminiTokenCache geminiOAuthService *GeminiOAuthService + refreshAPI *OAuthRefreshAPI + executor OAuthRefreshExecutor + refreshPolicy ProviderRefreshPolicy } func NewGeminiTokenProvider( @@ -30,9 +34,21 @@ func NewGeminiTokenProvider( accountRepo: accountRepo, tokenCache: tokenCache, geminiOAuthService: geminiOAuthService, + refreshPolicy: GeminiProviderRefreshPolicy(), } } +// SetRefreshAPI injects unified OAuth refresh API and executor. +func (p *GeminiTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) { + p.refreshAPI = api + p.executor = executor +} + +// SetRefreshPolicy injects caller-side refresh policy. +func (p *GeminiTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) { + p.refreshPolicy = policy +} + func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { if account == nil { return "", errors.New("account is nil") @@ -53,39 +69,31 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou // 2) Refresh if needed (pre-expiry skew). 
expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew - if needsRefresh && p.tokenCache != nil { - locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) - if err == nil && locked { - defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() - // Re-check after lock (another worker may have refreshed). - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil + if needsRefresh && p.refreshAPI != nil && p.executor != nil { + result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, geminiTokenRefreshSkew) + if err != nil { + if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { + return "", err } - - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh + } else if result.LockHeld { + if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil { + if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" { + return token, nil + } } + slog.Debug("gemini_token_lock_held_use_old", "account_id", account.ID) + } else { + account = result.Account expiresAt = account.GetCredentialAsTime("expires_at") - if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew { - if p.geminiOAuthService == nil { - return "", errors.New("gemini oauth service not configured") - } - tokenInfo, err := p.geminiOAuthService.RefreshAccountToken(ctx, account) - if err != nil { - return "", err - } - newCredentials := p.geminiOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } - account.Credentials = newCredentials - _ = p.accountRepo.Update(ctx, account) - expiresAt = account.GetCredentialAsTime("expires_at") - } + } + } 
else if needsRefresh && p.tokenCache != nil { + // Backward-compatible test path when refreshAPI is not injected. + locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if lockErr == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + } else if lockErr != nil { + slog.Warn("gemini_token_lock_failed", "account_id", account.ID, "error", lockErr) } } @@ -95,15 +103,14 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } // project_id is optional now: - // - If present: will use Code Assist API (requires project_id) - // - If absent: will use AI Studio API with OAuth token (like regular API key mode) - // Auto-detect project_id only if explicitly enabled via a credential flag + // - If present: use Code Assist API (requires project_id) + // - If absent: use AI Studio API with OAuth token. projectID := strings.TrimSpace(account.GetCredential("project_id")) autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true" if projectID == "" && autoDetectProjectID { if p.geminiOAuthService == nil { - return accessToken, nil // Fallback to AI Studio API mode + return accessToken, nil } var proxyURL string @@ -132,17 +139,15 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } } - // 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) + // 3) Populate cache with TTL. 
if p.tokenCache != nil { latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) if isStale && latestAccount != nil { - // 版本过时,使用 DB 中的最新 token slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID) accessToken = latestAccount.GetCredential("access_token") if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found after version check") } - // 不写入缓存,让下次请求重新处理 } else { ttl := 30 * time.Minute if expiresAt != nil { diff --git a/backend/internal/service/gemini_token_refresher.go b/backend/internal/service/gemini_token_refresher.go index 7dfc5521..d5e502da 100644 --- a/backend/internal/service/gemini_token_refresher.go +++ b/backend/internal/service/gemini_token_refresher.go @@ -13,6 +13,11 @@ func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiToke return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService} } +// CacheKey 返回用于分布式锁的缓存键 +func (r *GeminiTokenRefresher) CacheKey(account *Account) string { + return GeminiTokenCacheKey(account) +} + func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool { return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth } @@ -35,11 +40,7 @@ func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (m } newCredentials := r.geminiOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } + newCredentials = MergeCredentials(account.Credentials, newCredentials) return newCredentials, nil } diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 6990caca..537b5a3b 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -57,6 +57,10 @@ type Group struct { // 分组排序 SortOrder int + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch bool + DefaultMappedModel string + CreatedAt time.Time UpdatedAt time.Time 
diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index f3130c91..f6a94d15 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -19,8 +19,10 @@ import ( // 预编译正则表达式(避免每次调用重新编译) var ( - // 匹配 user_id 格式: user_{64位hex}_account__session_{uuid} - userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`) + // 匹配 user_id 格式: + // 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID) + // 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID) + userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`) // 匹配 User-Agent 版本号: xxx/x.y.z userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`) ) @@ -239,13 +241,16 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI return body, nil } - // 匹配格式: user_{64位hex}_account__session_{uuid} + // 匹配格式: + // 旧格式: user_{64位hex}_account__session_{uuid} + // 新格式: user_{64位hex}_account_{uuid}_session_{uuid} matches := userIDRegex.FindStringSubmatch(userID) if matches == nil { return body, nil } - sessionTail := matches[1] // 原始session UUID + // matches[1] = account UUID (可能为空), matches[2] = session UUID + sessionTail := matches[2] // 原始session UUID // 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式 seed := fmt.Sprintf("%d::%s", accountID, sessionTail) diff --git a/backend/internal/service/oauth_refresh_api.go b/backend/internal/service/oauth_refresh_api.go new file mode 100644 index 00000000..17b9128c --- /dev/null +++ b/backend/internal/service/oauth_refresh_api.go @@ -0,0 +1,159 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "strconv" + "time" +) + +// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器 +// TokenRefresher 接口的超集:增加了 CacheKey 方法用于分布式锁 +type OAuthRefreshExecutor interface { + TokenRefresher + + // CacheKey 返回用于分布式锁的缓存键(与 TokenProvider 使用的一致) + CacheKey(account 
*Account) string +} + +const refreshLockTTL = 30 * time.Second + +// OAuthRefreshResult 统一刷新结果 +type OAuthRefreshResult struct { + Refreshed bool // 实际执行了刷新 + NewCredentials map[string]any // 刷新后的 credentials(nil 表示未刷新) + Account *Account // 从 DB 重新读取的最新 account + LockHeld bool // 锁被其他 worker 持有(未执行刷新) +} + +// OAuthRefreshAPI 统一的 OAuth Token 刷新入口 +// 封装分布式锁、DB 重读、已刷新检查等通用逻辑 +type OAuthRefreshAPI struct { + accountRepo AccountRepository + tokenCache GeminiTokenCache // 可选,nil = 无锁 +} + +// NewOAuthRefreshAPI 创建统一刷新 API +func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI { + return &OAuthRefreshAPI{ + accountRepo: accountRepo, + tokenCache: tokenCache, + } +} + +// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token +// +// 流程: +// 1. 获取分布式锁 +// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token) +// 3. 二次检查是否仍需刷新 +// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑 +// 5. 设置 _token_version + 更新 DB +// 6. 释放锁 +func (api *OAuthRefreshAPI) RefreshIfNeeded( + ctx context.Context, + account *Account, + executor OAuthRefreshExecutor, + refreshWindow time.Duration, +) (*OAuthRefreshResult, error) { + cacheKey := executor.CacheKey(account) + + // 1. 获取分布式锁 + lockAcquired := false + if api.tokenCache != nil { + acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL) + if lockErr != nil { + // Redis 错误,降级为无锁刷新 + slog.Warn("oauth_refresh_lock_failed_degraded", + "account_id", account.ID, + "cache_key", cacheKey, + "error", lockErr, + ) + } else if !acquired { + // 锁被其他 worker 持有 + return &OAuthRefreshResult{LockHeld: true}, nil + } else { + lockAcquired = true + defer func() { _ = api.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + } + } + + // 2. 
从 DB 重读最新 account(锁保护下,确保使用最新的 refresh_token) + freshAccount, err := api.accountRepo.GetByID(ctx, account.ID) + if err != nil { + slog.Warn("oauth_refresh_db_reread_failed", + "account_id", account.ID, + "error", err, + ) + // 降级使用传入的 account + freshAccount = account + } else if freshAccount == nil { + freshAccount = account + } + + // 3. 二次检查是否仍需刷新(另一条路径可能已刷新) + if !executor.NeedsRefresh(freshAccount, refreshWindow) { + return &OAuthRefreshResult{ + Account: freshAccount, + }, nil + } + + // 4. 执行平台特定刷新逻辑 + newCredentials, refreshErr := executor.Refresh(ctx, freshAccount) + if refreshErr != nil { + return nil, refreshErr + } + + // 5. 设置版本号 + 更新 DB + if newCredentials != nil { + newCredentials["_token_version"] = time.Now().UnixMilli() + freshAccount.Credentials = newCredentials + if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil { + slog.Error("oauth_refresh_update_failed", + "account_id", freshAccount.ID, + "error", updateErr, + ) + return nil, fmt.Errorf("oauth refresh succeeded but DB update failed: %w", updateErr) + } + } + + _ = lockAcquired // suppress unused warning when tokenCache is nil + + return &OAuthRefreshResult{ + Refreshed: true, + NewCredentials: newCredentials, + Account: freshAccount, + }, nil +} + +// MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中 +func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any { + if newCreds == nil { + newCreds = make(map[string]any) + } + for k, v := range oldCreds { + if _, exists := newCreds[k]; !exists { + newCreds[k] = v + } + } + return newCreds +} + +// BuildClaudeAccountCredentials 为 Claude 平台构建 OAuth credentials map +// 消除 Claude 平台没有 BuildAccountCredentials 方法的问题 +func BuildClaudeAccountCredentials(tokenInfo *TokenInfo) map[string]any { + creds := map[string]any{ + "access_token": tokenInfo.AccessToken, + "token_type": tokenInfo.TokenType, + "expires_in": strconv.FormatInt(tokenInfo.ExpiresIn, 10), + "expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 
10), + } + if tokenInfo.RefreshToken != "" { + creds["refresh_token"] = tokenInfo.RefreshToken + } + if tokenInfo.Scope != "" { + creds["scope"] = tokenInfo.Scope + } + return creds +} diff --git a/backend/internal/service/oauth_refresh_api_test.go b/backend/internal/service/oauth_refresh_api_test.go new file mode 100644 index 00000000..6cf9371f --- /dev/null +++ b/backend/internal/service/oauth_refresh_api_test.go @@ -0,0 +1,395 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// ---------- mock helpers ---------- + +// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests. +type refreshAPIAccountRepo struct { + mockAccountRepoForGemini + account *Account // returned by GetByID + getByIDErr error + updateErr error + updateCalls int +} + +func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) { + if r.getByIDErr != nil { + return nil, r.getByIDErr + } + return r.account, nil +} + +func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error { + r.updateCalls++ + return r.updateErr +} + +// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests. +type refreshAPIExecutorStub struct { + needsRefresh bool + credentials map[string]any + err error + refreshCalls int +} + +func (e *refreshAPIExecutorStub) CanRefresh(_ *Account) bool { return true } + +func (e *refreshAPIExecutorStub) NeedsRefresh(_ *Account, _ time.Duration) bool { + return e.needsRefresh +} + +func (e *refreshAPIExecutorStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) { + e.refreshCalls++ + if e.err != nil { + return nil, e.err + } + return e.credentials, nil +} + +func (e *refreshAPIExecutorStub) CacheKey(account *Account) string { + return "test:api:" + account.Platform +} + +// refreshAPICacheStub implements GeminiTokenCache for OAuthRefreshAPI tests. 
+type refreshAPICacheStub struct { + lockResult bool + lockErr error + releaseCalls int +} + +func (c *refreshAPICacheStub) GetAccessToken(context.Context, string) (string, error) { + return "", nil +} + +func (c *refreshAPICacheStub) SetAccessToken(context.Context, string, string, time.Duration) error { + return nil +} + +func (c *refreshAPICacheStub) DeleteAccessToken(context.Context, string) error { return nil } + +func (c *refreshAPICacheStub) AcquireRefreshLock(context.Context, string, time.Duration) (bool, error) { + return c.lockResult, c.lockErr +} + +func (c *refreshAPICacheStub) ReleaseRefreshLock(context.Context, string) error { + c.releaseCalls++ + return nil +} + +// ========== RefreshIfNeeded tests ========== + +func TestRefreshIfNeeded_Success(t *testing.T) { + account := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "new-token"}, + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) + require.NotNil(t, result.NewCredentials) + require.Equal(t, "new-token", result.NewCredentials["access_token"]) + require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set + require.Equal(t, 1, repo.updateCalls) // DB updated + require.Equal(t, 1, cache.releaseCalls) // lock released + require.Equal(t, 1, executor.refreshCalls) +} + +func TestRefreshIfNeeded_LockHeld(t *testing.T) { + account := &Account{ID: 2, Platform: PlatformAnthropic} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: false} // lock not acquired + executor := &refreshAPIExecutorStub{needsRefresh: true} + + api := NewOAuthRefreshAPI(repo, cache) + result, err 
:= api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.LockHeld) + require.False(t, result.Refreshed) + require.Equal(t, 0, repo.updateCalls) + require.Equal(t, 0, executor.refreshCalls) +} + +func TestRefreshIfNeeded_LockErrorDegrades(t *testing.T) { + account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockErr: errors.New("redis down")} // lock error + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "degraded-token"}, + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) // still refreshed (degraded mode) + require.Equal(t, 1, repo.updateCalls) // DB updated + require.Equal(t, 0, cache.releaseCalls) // no lock to release + require.Equal(t, 1, executor.refreshCalls) +} + +func TestRefreshIfNeeded_NoCacheNoLock(t *testing.T) { + account := &Account{ID: 4, Platform: PlatformGemini, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{account: account} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "no-cache-token"}, + } + + api := NewOAuthRefreshAPI(repo, nil) // no cache = no lock + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) + require.Equal(t, 1, repo.updateCalls) +} + +func TestRefreshIfNeeded_AlreadyRefreshed(t *testing.T) { + account := &Account{ID: 5, Platform: PlatformAnthropic} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{needsRefresh: false} // already refreshed + + api := NewOAuthRefreshAPI(repo, cache) + 
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.False(t, result.Refreshed) + require.False(t, result.LockHeld) + require.NotNil(t, result.Account) // returns fresh account + require.Equal(t, 0, repo.updateCalls) + require.Equal(t, 0, executor.refreshCalls) +} + +func TestRefreshIfNeeded_RefreshError(t *testing.T) { + account := &Account{ID: 6, Platform: PlatformAnthropic} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + err: errors.New("invalid_grant: token revoked"), + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "invalid_grant") + require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error + require.Equal(t, 1, cache.releaseCalls) // lock still released via defer +} + +func TestRefreshIfNeeded_DBUpdateError(t *testing.T) { + account := &Account{ID: 7, Platform: PlatformGemini, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{ + account: account, + updateErr: errors.New("db connection lost"), + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "token"}, + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "DB update failed") + require.Equal(t, 1, repo.updateCalls) // attempted +} + +func TestRefreshIfNeeded_DBRereadFails(t *testing.T) { + account := &Account{ID: 8, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{ + account: nil, // GetByID returns nil + 
getByIDErr: errors.New("db timeout"), + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "fallback-token"}, + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) + require.Equal(t, 1, executor.refreshCalls) // still refreshes using passed-in account +} + +func TestRefreshIfNeeded_NilCredentials(t *testing.T) { + account := &Account{ID: 9, Platform: PlatformGemini, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: nil, // Refresh returns nil credentials + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) + require.Nil(t, result.NewCredentials) + require.Equal(t, 0, repo.updateCalls) // no DB update when credentials are nil +} + +// ========== MergeCredentials tests ========== + +func TestMergeCredentials_Basic(t *testing.T) { + old := map[string]any{"a": "1", "b": "2", "c": "3"} + new := map[string]any{"a": "new", "d": "4"} + + result := MergeCredentials(old, new) + + require.Equal(t, "new", result["a"]) // new value preserved + require.Equal(t, "2", result["b"]) // old value kept + require.Equal(t, "3", result["c"]) // old value kept + require.Equal(t, "4", result["d"]) // new value preserved +} + +func TestMergeCredentials_NilNew(t *testing.T) { + old := map[string]any{"a": "1"} + + result := MergeCredentials(old, nil) + + require.NotNil(t, result) + require.Equal(t, "1", result["a"]) +} + +func TestMergeCredentials_NilOld(t *testing.T) { + new := map[string]any{"a": "1"} + + result := MergeCredentials(nil, new) 
+ + require.Equal(t, "1", result["a"]) +} + +func TestMergeCredentials_BothNil(t *testing.T) { + result := MergeCredentials(nil, nil) + require.NotNil(t, result) + require.Empty(t, result) +} + +func TestMergeCredentials_NewOverridesOld(t *testing.T) { + old := map[string]any{"access_token": "old-token", "refresh_token": "old-refresh"} + new := map[string]any{"access_token": "new-token"} + + result := MergeCredentials(old, new) + + require.Equal(t, "new-token", result["access_token"]) // overridden + require.Equal(t, "old-refresh", result["refresh_token"]) // preserved +} + +// ========== BuildClaudeAccountCredentials tests ========== + +func TestBuildClaudeAccountCredentials_Full(t *testing.T) { + tokenInfo := &TokenInfo{ + AccessToken: "at-123", + TokenType: "Bearer", + ExpiresIn: 3600, + ExpiresAt: 1700000000, + RefreshToken: "rt-456", + Scope: "openid", + } + + creds := BuildClaudeAccountCredentials(tokenInfo) + + require.Equal(t, "at-123", creds["access_token"]) + require.Equal(t, "Bearer", creds["token_type"]) + require.Equal(t, "3600", creds["expires_in"]) + require.Equal(t, "1700000000", creds["expires_at"]) + require.Equal(t, "rt-456", creds["refresh_token"]) + require.Equal(t, "openid", creds["scope"]) +} + +func TestBuildClaudeAccountCredentials_Minimal(t *testing.T) { + tokenInfo := &TokenInfo{ + AccessToken: "at-789", + TokenType: "Bearer", + ExpiresIn: 7200, + ExpiresAt: 1700003600, + } + + creds := BuildClaudeAccountCredentials(tokenInfo) + + require.Equal(t, "at-789", creds["access_token"]) + require.Equal(t, "Bearer", creds["token_type"]) + require.Equal(t, "7200", creds["expires_in"]) + require.Equal(t, "1700003600", creds["expires_at"]) + _, hasRefresh := creds["refresh_token"] + _, hasScope := creds["scope"] + require.False(t, hasRefresh, "refresh_token should not be set when empty") + require.False(t, hasScope, "scope should not be set when empty") +} + +// ========== BackgroundRefreshPolicy tests ========== + +func 
TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) { + p := DefaultBackgroundRefreshPolicy() + + require.ErrorIs(t, p.handleLockHeld(), errRefreshSkipped) + require.ErrorIs(t, p.handleAlreadyRefreshed(), errRefreshSkipped) +} + +func TestBackgroundRefreshPolicy_SuccessOverride(t *testing.T) { + p := BackgroundRefreshPolicy{ + OnLockHeld: BackgroundSkipAsSuccess, + OnAlreadyRefresh: BackgroundSkipAsSuccess, + } + + require.NoError(t, p.handleLockHeld()) + require.NoError(t, p.handleAlreadyRefreshed()) +} + +// ========== ProviderRefreshPolicy tests ========== + +func TestClaudeProviderRefreshPolicy(t *testing.T) { + p := ClaudeProviderRefreshPolicy() + require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError) + require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld) + require.Equal(t, time.Minute, p.FailureTTL) +} + +func TestOpenAIProviderRefreshPolicy(t *testing.T) { + p := OpenAIProviderRefreshPolicy() + require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError) + require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld) + require.Equal(t, time.Minute, p.FailureTTL) +} + +func TestGeminiProviderRefreshPolicy(t *testing.T) { + p := GeminiProviderRefreshPolicy() + require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError) + require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld) + require.Equal(t, time.Duration(0), p.FailureTTL) +} + +func TestAntigravityProviderRefreshPolicy(t *testing.T) { + p := AntigravityProviderRefreshPolicy() + require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError) + require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld) + require.Equal(t, time.Duration(0), p.FailureTTL) +} diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 99013ce5..789888cb 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -319,7 +319,7 @@ func (s 
*defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } - if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() { + if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() || !account.IsSchedulable() { _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } @@ -342,6 +342,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( } cfg := s.service.schedulingConfig() + // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 if s.service.concurrencyService != nil { return &AccountSelectionResult{ Account: account, @@ -590,7 +591,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( filtered = append(filtered, account) loadReq = append(loadReq, AccountWithConcurrency{ ID: account.ID, - MaxConcurrency: account.Concurrency, + MaxConcurrency: account.EffectiveLoadFactor(), }) } if len(filtered) == 0 { @@ -686,16 +687,20 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( for i := 0; i < len(selectionOrder); i++ { candidate := selectionOrder[i] - result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency) + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if acquireErr != nil { return nil, len(candidates), topK, loadSkew, acquireErr } if result != nil && result.Acquired { if req.SessionHash != "" { - _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID) + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID) } return &AccountSelectionResult{ - Account: candidate.account, + Account: 
fresh, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, len(candidates), topK, loadSkew, nil @@ -703,16 +708,24 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( } cfg := s.service.schedulingConfig() - candidate := selectionOrder[0] - return &AccountSelectionResult{ - Account: candidate.account, - WaitPlan: &AccountWaitPlan{ - AccountID: candidate.account.ID, - MaxConcurrency: candidate.account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, len(candidates), topK, loadSkew, nil + // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 + for _, candidate := range selectionOrder { + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } + return &AccountSelectionResult{ + Account: fresh, + WaitPlan: &AccountWaitPlan{ + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, len(candidates), topK, loadSkew, nil + } + + return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts } func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 7f6f1b66..977c4ee8 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -12,6 +12,78 @@ import ( "github.com/stretchr/testify/require" ) +type openAISnapshotCacheStub struct { + SchedulerCache + snapshotAccounts []*Account + accountsByID map[int64]*Account +} + +func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { + if 
len(s.snapshotAccounts) == 0 { + return nil, false, nil + } + out := make([]*Account, 0, len(s.snapshotAccounts)) + for _, account := range s.snapshotAccounts { + if account == nil { + continue + } + cloned := *account + out = append(out, &cloned) + } + return out, true, nil +} + +func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int64) (*Account, error) { + if s.accountsByID == nil { + return nil, nil + } + account := s.accountsByID[accountID] + if account == nil { + return nil, nil + } + cloned := *account + return &cloned, nil +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) { + ctx := context.Background() + groupID := int64(10101) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + staleSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} + freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} + snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}} + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: 
snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})} + + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(31002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRateLimitedSnapshotCandidate(t *testing.T) { + ctx := context.Background() + groupID := int64(10102) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + stalePrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + freshPrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} + freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}} + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService} + + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(32002), 
account.ID) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { ctx := context.Background() groupID := int64(9) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 16befb82..29f2b672 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -1,14 +1,18 @@ package service import ( - _ "embed" + "fmt" "strings" ) -//go:embed prompts/codex_cli_instructions.md -var codexCLIInstructions string - var codexModelMap = map[string]string{ + "gpt-5.4": "gpt-5.4", + "gpt-5.4-none": "gpt-5.4", + "gpt-5.4-low": "gpt-5.4", + "gpt-5.4-medium": "gpt-5.4", + "gpt-5.4-high": "gpt-5.4", + "gpt-5.4-xhigh": "gpt-5.4", + "gpt-5.4-chat-latest": "gpt-5.4", "gpt-5.3": "gpt-5.3-codex", "gpt-5.3-none": "gpt-5.3-codex", "gpt-5.3-low": "gpt-5.3-codex", @@ -70,7 +74,7 @@ type codexTransformResult struct { PromptCacheKey string } -func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult { +func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult { result := codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 needsToolContinuation := NeedsToolContinuation(reqBody) @@ -88,15 +92,26 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran result.NormalizedModel = normalizedModel } - // OAuth 走 ChatGPT internal API 时,store 必须为 false;显式 true 也会强制覆盖。 - // 避免上游返回 "Store must be set to false"。 - if v, ok := reqBody["store"].(bool); !ok || v { - reqBody["store"] = false - result.Modified = true - } - if v, ok := reqBody["stream"].(bool); !ok || !v { - reqBody["stream"] = true - result.Modified = true + if isCompact { + if _, ok := reqBody["store"]; ok { + delete(reqBody, "store") + result.Modified = true + } + if _, ok := reqBody["stream"]; ok { + delete(reqBody, "stream") + result.Modified = true + } + } else 
{ + // OAuth 走 ChatGPT internal API 时,store 必须为 false;显式 true 也会强制覆盖。 + // 避免上游返回 "Store must be set to false"。 + if v, ok := reqBody["store"].(bool); !ok || v { + reqBody["store"] = false + result.Modified = true + } + if v, ok := reqBody["stream"].(bool); !ok || !v { + reqBody["stream"] = true + result.Modified = true + } } // Strip parameters unsupported by codex models via the Responses API. @@ -114,6 +129,41 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran } } + // 兼容遗留的 functions 和 function_call,转换为 tools 和 tool_choice + if functionsRaw, ok := reqBody["functions"]; ok { + if functions, k := functionsRaw.([]any); k { + tools := make([]any, 0, len(functions)) + for _, f := range functions { + tools = append(tools, map[string]any{ + "type": "function", + "function": f, + }) + } + reqBody["tools"] = tools + } + delete(reqBody, "functions") + result.Modified = true + } + + if fcRaw, ok := reqBody["function_call"]; ok { + if fcStr, ok := fcRaw.(string); ok { + // e.g. "auto", "none" + reqBody["tool_choice"] = fcStr + } else if fcObj, ok := fcRaw.(map[string]any); ok { + // e.g. {"name": "my_func"} + if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" { + reqBody["tool_choice"] = map[string]any{ + "type": "function", + "function": map[string]any{ + "name": name, + }, + } + } + } + delete(reqBody, "function_call") + result.Modified = true + } + if normalizeCodexTools(reqBody) { result.Modified = true } @@ -132,6 +182,22 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran input = filterCodexInput(input, needsToolContinuation) reqBody["input"] = input result.Modified = true + } else if inputStr, ok := reqBody["input"].(string); ok { + // ChatGPT codex endpoint requires input to be a list, not a string. + // Convert string input to the expected message array format. 
+ trimmed := strings.TrimSpace(inputStr) + if trimmed != "" { + reqBody["input"] = []any{ + map[string]any{ + "type": "message", + "role": "user", + "content": inputStr, + }, + } + } else { + reqBody["input"] = []any{} + } + result.Modified = true } return result @@ -154,6 +220,9 @@ func normalizeCodexModel(model string) string { normalized := strings.ToLower(modelID) + if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { + return "gpt-5.4" + } if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") { return "gpt-5.2-codex" } @@ -193,6 +262,29 @@ func normalizeCodexModel(model string) string { return "gpt-5.1" } +func SupportsVerbosity(model string) bool { + if !strings.HasPrefix(model, "gpt-") { + return true + } + + var major, minor int + n, _ := fmt.Sscanf(model, "gpt-%d.%d", &major, &minor) + + if major > 5 { + return true + } + if major < 5 { + return false + } + + // gpt-5 + if n == 1 { + return true + } + + return minor >= 3 +} + func getNormalizedCodexModel(modelID string) string { if modelID == "" { return "" @@ -209,72 +301,13 @@ func getNormalizedCodexModel(modelID string) string { return "" } -func getOpenCodeCodexHeader() string { - // 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。 - // 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions,避免读写缓存与外网依赖。 - return getCodexCLIInstructions() -} - -func getCodexCLIInstructions() string { - return codexCLIInstructions -} - -func GetOpenCodeInstructions() string { - return getOpenCodeCodexHeader() -} - -// GetCodexCLIInstructions 返回内置的 Codex CLI 指令内容。 -func GetCodexCLIInstructions() string { - return getCodexCLIInstructions() -} - -// applyInstructions 处理 instructions 字段 -// isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令) -// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖 +// applyInstructions 处理 instructions 字段:仅在 instructions 为空时填充默认值。 func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { - if isCodexCLI { - return 
applyCodexCLIInstructions(reqBody) - } - return applyOpenCodeInstructions(reqBody) -} - -// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions -// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源) -func applyCodexCLIInstructions(reqBody map[string]any) bool { if !isInstructionsEmpty(reqBody) { - return false // 已有有效 instructions,不修改 + return false } - - instructions := strings.TrimSpace(getCodexCLIInstructions()) - if instructions != "" { - reqBody["instructions"] = instructions - return true - } - - return false -} - -// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名) -// 优先使用内置 Codex CLI 指令覆盖 -func applyOpenCodeInstructions(reqBody map[string]any) bool { - instructions := strings.TrimSpace(getOpenCodeCodexHeader()) - existingInstructions, _ := reqBody["instructions"].(string) - existingInstructions = strings.TrimSpace(existingInstructions) - - if instructions != "" { - if existingInstructions != instructions { - reqBody["instructions"] = instructions - return true - } - } else if existingInstructions == "" { - codexInstructions := strings.TrimSpace(getCodexCLIInstructions()) - if codexInstructions != "" { - reqBody["instructions"] = codexInstructions - return true - } - } - - return false + reqBody["instructions"] = "You are a helpful coding assistant." 
+ return true } // isInstructionsEmpty 检查 instructions 字段是否为空 @@ -305,6 +338,19 @@ func filterCodexInput(input []any, preserveReferences bool) []any { continue } typ, _ := m["type"].(string) + + // 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id; + // 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。 + fixCallIDPrefix := func(id string) string { + if id == "" || strings.HasPrefix(id, "fc") { + return id + } + if strings.HasPrefix(id, "call_") { + return "fc" + strings.TrimPrefix(id, "call_") + } + return "fc_" + id + } + if typ == "item_reference" { if !preserveReferences { continue @@ -313,6 +359,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any { for key, value := range m { newItem[key] = value } + if id, ok := newItem["id"].(string); ok && strings.HasPrefix(id, "call_") { + newItem["id"] = fixCallIDPrefix(id) + } filtered = append(filtered, newItem) continue } @@ -332,10 +381,20 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } if isCodexToolCallItemType(typ) { - if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" { + callID, ok := m["call_id"].(string) + if !ok || strings.TrimSpace(callID) == "" { if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" { + callID = id ensureCopy() - newItem["call_id"] = id + newItem["call_id"] = callID + } + } + + if callID != "" { + fixedCallID := fixCallIDPrefix(callID) + if fixedCallID != callID { + ensureCopy() + newItem["call_id"] = fixedCallID } } } diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 27093f6c..ae6f8555 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -18,7 +18,7 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody, false) + applyCodexOAuthTransform(reqBody, false, 
false) // 未显式设置 store=true,默认为 false。 store, ok := reqBody["store"].(bool) @@ -39,6 +39,57 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { second, ok := input[1].(map[string]any) require.True(t, ok) require.Equal(t, "o1", second["id"]) + require.Equal(t, "fc1", second["call_id"]) +} + +func TestApplyCodexOAuthTransform_ToolContinuationPreservesNativeMessageAndReasoningIDs(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"}, + map[string]any{"type": "item_reference", "id": "rs_123"}, + }, + "tool_choice": "auto", + } + + applyCodexOAuthTransform(reqBody, false, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "msg_0", first["id"]) + + second, ok := input[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "rs_123", second["id"]) +} + +func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "item_reference", "id": "call_1"}, + map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok"}, + }, + "tool_choice": "auto", + } + + applyCodexOAuthTransform(reqBody, false, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "fc1", first["id"]) + + second, ok := input[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "fc1", second["call_id"]) } func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { @@ -53,7 +104,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody, false) + 
applyCodexOAuthTransform(reqBody, false, false) store, ok := reqBody["store"].(bool) require.True(t, ok) @@ -72,13 +123,29 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody, false) + applyCodexOAuthTransform(reqBody, false, false) store, ok := reqBody["store"].(bool) require.True(t, ok) require.False(t, store) } +func TestApplyCodexOAuthTransform_CompactForcesNonStreaming(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.1-codex", + "store": true, + "stream": true, + } + + result := applyCodexOAuthTransform(reqBody, true, true) + + _, hasStore := reqBody["store"] + require.False(t, hasStore) + _, hasStream := reqBody["stream"] + require.False(t, hasStream) + require.True(t, result.Modified) +} + func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) { // 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。 @@ -89,7 +156,7 @@ func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs( }, } - applyCodexOAuthTransform(reqBody, false) + applyCodexOAuthTransform(reqBody, false, false) store, ok := reqBody["store"].(bool) require.True(t, ok) @@ -138,7 +205,7 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction }, } - applyCodexOAuthTransform(reqBody, false) + applyCodexOAuthTransform(reqBody, false, false) tools, ok := reqBody["tools"].([]any) require.True(t, ok) @@ -158,7 +225,7 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { "input": []any{}, } - applyCodexOAuthTransform(reqBody, false) + applyCodexOAuthTransform(reqBody, false, false) input, ok := reqBody["input"].([]any) require.True(t, ok) @@ -167,6 +234,10 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { func TestNormalizeCodexModel_Gpt53(t *testing.T) { cases := map[string]string{ + "gpt-5.4": "gpt-5.4", + "gpt-5.4-high": "gpt-5.4", + "gpt-5.4-chat-latest": "gpt-5.4", + "gpt 5.4": "gpt-5.4", 
"gpt-5.3": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", @@ -189,7 +260,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *test "instructions": "existing instructions", } - result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true + result := applyCodexOAuthTransform(reqBody, true, false) // isCodexCLI=true instructions, ok := reqBody["instructions"].(string) require.True(t, ok) @@ -206,7 +277,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T // 没有 instructions 字段 } - result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true + result := applyCodexOAuthTransform(reqBody, true, false) // isCodexCLI=true instructions, ok := reqBody["instructions"].(string) require.True(t, ok) @@ -214,20 +285,63 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T require.True(t, result.Modified) } -func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) { - // 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖 +func TestApplyCodexOAuthTransform_NonCodexCLI_PreservesExistingInstructions(t *testing.T) { + // 非 Codex CLI 场景:已有 instructions 时保留客户端的值,不再覆盖 reqBody := map[string]any{ "model": "gpt-5.1", "instructions": "old instructions", } - result := applyCodexOAuthTransform(reqBody, false) // isCodexCLI=false + applyCodexOAuthTransform(reqBody, false, false) // isCodexCLI=false instructions, ok := reqBody["instructions"].(string) require.True(t, ok) - require.NotEqual(t, "old instructions", instructions) + require.Equal(t, "old instructions", instructions) +} + +func TestApplyCodexOAuthTransform_StringInputConvertedToArray(t *testing.T) { + reqBody := map[string]any{"model": "gpt-5.4", "input": "Hello, world!"} + result := applyCodexOAuthTransform(reqBody, false, false) require.True(t, result.Modified) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + msg, ok := 
input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "message", msg["type"]) + require.Equal(t, "user", msg["role"]) + require.Equal(t, "Hello, world!", msg["content"]) +} + +func TestApplyCodexOAuthTransform_EmptyStringInputBecomesEmptyArray(t *testing.T) { + reqBody := map[string]any{"model": "gpt-5.4", "input": ""} + result := applyCodexOAuthTransform(reqBody, false, false) + require.True(t, result.Modified) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 0) +} + +func TestApplyCodexOAuthTransform_WhitespaceStringInputBecomesEmptyArray(t *testing.T) { + reqBody := map[string]any{"model": "gpt-5.4", "input": " "} + result := applyCodexOAuthTransform(reqBody, false, false) + require.True(t, result.Modified) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 0) +} + +func TestApplyCodexOAuthTransform_StringInputWithToolsField(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": "Run the tests", + "tools": []any{map[string]any{"type": "function", "name": "bash"}}, + } + applyCodexOAuthTransform(reqBody, false, false) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) } func TestIsInstructionsEmpty(t *testing.T) { diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go new file mode 100644 index 00000000..9529f6be --- /dev/null +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -0,0 +1,509 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// ForwardAsChatCompletions accepts a Chat Completions request body, 
converts it +// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts +// the response back to Chat Completions format. All account types (OAuth and API +// Key) go through the Responses API conversion path since the upstream only +// exposes the /v1/responses endpoint. +func (s *OpenAIGatewayService) ForwardAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + promptCacheKey string, + defaultMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + + // 1. Parse Chat Completions request + var chatReq apicompat.ChatCompletionsRequest + if err := json.Unmarshal(body, &chatReq); err != nil { + return nil, fmt.Errorf("parse chat completions request: %w", err) + } + originalModel := chatReq.Model + clientStream := chatReq.Stream + includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage + + // 2. Convert to Responses and forward + // ChatCompletionsToResponses always sets Stream=true (upstream always streams). + responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq) + if err != nil { + return nil, fmt.Errorf("convert chat completions to responses: %w", err) + } + + // 3. Model mapping + mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + responsesReq.Model = mappedModel + + logger.L().Debug("openai chat_completions: model mapping applied", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("mapped_model", mappedModel), + zap.Bool("stream", clientStream), + ) + + // 4. 
Marshal Responses request body, then apply OAuth codex transform + responsesBody, err := json.Marshal(responsesReq) + if err != nil { + return nil, fmt.Errorf("marshal responses request: %w", err) + } + + if account.Type == AccountTypeOAuth { + var reqBody map[string]any + if err := json.Unmarshal(responsesBody, &reqBody); err != nil { + return nil, fmt.Errorf("unmarshal for codex transform: %w", err) + } + codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.PromptCacheKey != "" { + promptCacheKey = codexResult.PromptCacheKey + } else if promptCacheKey != "" { + reqBody["prompt_cache_key"] = promptCacheKey + } + responsesBody, err = json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("remarshal after codex transform: %w", err) + } + } + + // 5. Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 6. Build upstream request + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false) + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + if promptCacheKey != "" { + upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey)) + } + + // 7. 
Send request + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 8. Handle error response with failover + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && 
(isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } + } + return s.handleChatCompletionsErrorResponse(resp, c, account) + } + + // 9. Handle normal response + var result *OpenAIForwardResult + var handleErr error + if clientStream { + result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime) + } else { + result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + } + + // Propagate ServiceTier and ReasoningEffort to result for billing + if handleErr == nil && result != nil { + if responsesReq.ServiceTier != "" { + st := responsesReq.ServiceTier + result.ServiceTier = &st + } + if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" { + re := responsesReq.Reasoning.Effort + result.ReasoningEffort = &re + } + } + + // Extract and save Codex usage snapshot from response headers (for OAuth accounts) + if handleErr == nil && account.Type == AccountTypeOAuth { + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + } + + return result, handleErr +} + +// handleChatCompletionsErrorResponse reads an upstream error and returns it in +// OpenAI Chat Completions error format. +func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse( + resp *http.Response, + c *gin.Context, + account *Account, +) (*OpenAIForwardResult, error) { + return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError) +} + +// handleChatBufferedStreamingResponse reads all Responses SSE events from the +// upstream, finds the terminal event, converts to a Chat Completions JSON +// response, and writes it to the client. 
+func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var finalResponse *apicompat.ResponsesResponse + var usage OpenAIUsage + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + payload := line[6:] + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai chat_completions buffered: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + + if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil { + finalResponse = event.Response + if event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai chat_completions buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + if finalResponse == nil { + writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event") + return nil, fmt.Errorf("upstream stream ended without terminal event") + } + + chatResp := 
apicompat.ResponsesToChatCompletions(finalResponse, originalModel) + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, chatResp) + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// handleChatStreamingResponse reads Responses SSE events from upstream, +// converts each to Chat Completions SSE chunks, and writes them to the client. +func (s *OpenAIGatewayService) handleChatStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + includeUsage bool, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + state := apicompat.NewResponsesEventToChatState() + state.Model = originalModel + state.IncludeUsage = includeUsage + + var usage OpenAIUsage + var firstTokenMs *int + firstChunk := true + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + resultWithUsage := func() *OpenAIForwardResult { + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + } + } + + processDataLine := func(payload string) bool { + if 
firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai chat_completions stream: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + return false + } + + // Extract usage from completion events + if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil && event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + + chunks := apicompat.ResponsesEventToChatChunks(&event, state) + for _, chunk := range chunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + logger.L().Warn("openai chat_completions stream: failed to marshal chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + logger.L().Info("openai chat_completions stream: client disconnected", + zap.String("request_id", requestID), + ) + return true + } + } + if len(chunks) > 0 { + c.Writer.Flush() + } + return false + } + + finalizeStream := func() (*OpenAIForwardResult, error) { + if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 { + for _, chunk := range finalChunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + continue + } + fmt.Fprint(c.Writer, sse) //nolint:errcheck + } + } + // Send [DONE] sentinel + fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck + c.Writer.Flush() + return resultWithUsage(), nil + } + + handleScanErr := func(err error) { + if err != nil && !errors.Is(err, context.Canceled) && 
!errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai chat_completions stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + // Determine keepalive interval + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + + // No keepalive: fast synchronous path + if keepaliveInterval <= 0 { + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + } + handleScanErr(scanner.Err()) + return finalizeStream() + } + + // With keepalive: goroutine + channel + select + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + go func() { + defer close(events) + for scanner.Scan() { + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + keepaliveTicker := time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + lastDataAt := time.Now() + + for { + select { + case ev, ok := <-events: + if !ok { + return finalizeStream() + } + if ev.err != nil { + handleScanErr(ev.err) + return finalizeStream() + } + lastDataAt = time.Now() + line := ev.line + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + + case <-keepaliveTicker.C: + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // Send SSE comment as keepalive + if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil { + logger.L().Info("openai chat_completions stream: 
client disconnected during keepalive", + zap.String("request_id", requestID), + ) + return resultWithUsage(), nil + } + c.Writer.Flush() + } + } +} + +// writeChatCompletionsError writes an error response in OpenAI Chat Completions format. +func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go new file mode 100644 index 00000000..58714571 --- /dev/null +++ b/backend/internal/service/openai_gateway_messages.go @@ -0,0 +1,538 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// ForwardAsAnthropic accepts an Anthropic Messages request body, converts it +// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts +// the response back to Anthropic Messages format. This enables Claude Code +// clients to access OpenAI models through the standard /v1/messages endpoint. +func (s *OpenAIGatewayService) ForwardAsAnthropic( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + promptCacheKey string, + defaultMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + + // 1. Parse Anthropic request + var anthropicReq apicompat.AnthropicRequest + if err := json.Unmarshal(body, &anthropicReq); err != nil { + return nil, fmt.Errorf("parse anthropic request: %w", err) + } + originalModel := anthropicReq.Model + clientStream := anthropicReq.Stream // client's original stream preference + + // 2. 
Convert Anthropic → Responses + responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq) + if err != nil { + return nil, fmt.Errorf("convert anthropic to responses: %w", err) + } + + // Upstream always uses streaming (upstream may not support sync mode). + // The client's original preference determines the response format. + responsesReq.Stream = true + isStream := true + + // 2b. Handle BetaFastMode → service_tier: "priority" + if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) { + responsesReq.ServiceTier = "priority" + } + + // 3. Model mapping + mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + responsesReq.Model = mappedModel + + logger.L().Debug("openai messages: model mapping applied", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("mapped_model", mappedModel), + zap.Bool("stream", isStream), + ) + + // 4. Marshal Responses request body, then apply OAuth codex transform + responsesBody, err := json.Marshal(responsesReq) + if err != nil { + return nil, fmt.Errorf("marshal responses request: %w", err) + } + + if account.Type == AccountTypeOAuth { + var reqBody map[string]any + if err := json.Unmarshal(responsesBody, &reqBody); err != nil { + return nil, fmt.Errorf("unmarshal for codex transform: %w", err) + } + codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.PromptCacheKey != "" { + promptCacheKey = codexResult.PromptCacheKey + } else if promptCacheKey != "" { + reqBody["prompt_cache_key"] = promptCacheKey + } + // OAuth codex transform forces stream=true upstream, so always use + // the streaming response handler regardless of what the client asked. + isStream = true + responsesBody, err = json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("remarshal after codex transform: %w", err) + } + } + + // 5. 
Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 6. Build upstream request + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false) + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + // Override session_id with a deterministic UUID derived from the isolated + // session key, ensuring different API keys produce different upstream sessions. + if promptCacheKey != "" { + apiKeyID := getAPIKeyIDFromContext(c) + upstreamReq.Header.Set("session_id", generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey))) + } + + // 7. Send request + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 8. 
Handle error response with failover + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } + } + // Non-failover error: return Anthropic-formatted error to client + return s.handleAnthropicErrorResponse(resp, c, account) + } + + // 9. Handle normal response + // Upstream is always streaming; choose response format based on client preference. + var result *OpenAIForwardResult + var handleErr error + if clientStream { + result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime) + } else { + // Client wants JSON: buffer the streaming response and assemble a JSON reply. 
+ result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + } + + // Propagate ServiceTier and ReasoningEffort to result for billing + if handleErr == nil && result != nil { + if responsesReq.ServiceTier != "" { + st := responsesReq.ServiceTier + result.ServiceTier = &st + } + if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" { + re := responsesReq.Reasoning.Effort + result.ReasoningEffort = &re + } + } + + // Extract and save Codex usage snapshot from response headers (for OAuth accounts) + if handleErr == nil && account.Type == AccountTypeOAuth { + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + } + + return result, handleErr +} + +// handleAnthropicErrorResponse reads an upstream error and returns it in +// Anthropic error format. +func (s *OpenAIGatewayService) handleAnthropicErrorResponse( + resp *http.Response, + c *gin.Context, + account *Account, +) (*OpenAIForwardResult, error) { + return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError) +} + +// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from +// the upstream streaming response, finds the terminal event (response.completed +// / response.incomplete / response.failed), converts the complete response to +// Anthropic Messages JSON format, and writes it to the client. +// This is used when the client requested stream=false but the upstream is always +// streaming. 
+func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var finalResponse *apicompat.ResponsesResponse + var usage OpenAIUsage + + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + payload := line[6:] + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai messages buffered: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + + // Terminal events carry the complete ResponsesResponse with output + usage. 
+ if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil { + finalResponse = event.Response + if event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai messages buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + if finalResponse == nil { + writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event") + return nil, fmt.Errorf("upstream stream ended without terminal event") + } + + anthropicResp := apicompat.ResponsesToAnthropic(finalResponse, originalModel) + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, anthropicResp) + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// handleAnthropicStreamingResponse reads Responses SSE events from upstream, +// converts each to Anthropic SSE events, and writes them to the client. +// When StreamKeepaliveInterval is configured, it uses a goroutine + channel +// pattern to send Anthropic ping events during periods of upstream silence, +// preventing proxy/client timeout disconnections. 
+func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + state := apicompat.NewResponsesEventToAnthropicState() + state.Model = originalModel + var usage OpenAIUsage + var firstTokenMs *int + firstChunk := true + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + // resultWithUsage builds the final result snapshot. + resultWithUsage := func() *OpenAIForwardResult { + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + } + } + + // processDataLine handles a single "data: ..." SSE line from upstream. + // Returns (clientDisconnected bool). 
+ processDataLine := func(payload string) bool { + if firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai messages stream: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + return false + } + + // Extract usage from completion events + if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil && event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + + // Convert to Anthropic events + events := apicompat.ResponsesEventToAnthropicEvents(&event, state) + for _, evt := range events { + sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) + if err != nil { + logger.L().Warn("openai messages stream: failed to marshal event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + logger.L().Info("openai messages stream: client disconnected", + zap.String("request_id", requestID), + ) + return true + } + } + if len(events) > 0 { + c.Writer.Flush() + } + return false + } + + // finalizeStream sends any remaining Anthropic events and returns the result. 
+ finalizeStream := func() (*OpenAIForwardResult, error) { + if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 { + for _, evt := range finalEvents { + sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) + if err != nil { + continue + } + fmt.Fprint(c.Writer, sse) //nolint:errcheck + } + c.Writer.Flush() + } + return resultWithUsage(), nil + } + + // handleScanErr logs scanner errors if meaningful. + handleScanErr := func(err error) { + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai messages stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + // ── Determine keepalive interval ── + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + + // ── No keepalive: fast synchronous path (no goroutine overhead) ── + if keepaliveInterval <= 0 { + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + } + handleScanErr(scanner.Err()) + return finalizeStream() + } + + // ── With keepalive: goroutine + channel + select ── + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + go func() { + defer close(events) + for scanner.Scan() { + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + keepaliveTicker := time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + lastDataAt := time.Now() + + for { + select { + case 
ev, ok := <-events: + if !ok { + // Upstream closed + return finalizeStream() + } + if ev.err != nil { + handleScanErr(ev.err) + return finalizeStream() + } + lastDataAt = time.Now() + line := ev.line + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + + case <-keepaliveTicker.C: + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // Send Anthropic-format ping event + if _, err := fmt.Fprint(c.Writer, "event: ping\ndata: {\"type\":\"ping\"}\n\n"); err != nil { + // Client disconnected + logger.L().Info("openai messages stream: client disconnected during keepalive", + zap.String("request_id", requestID), + ) + return resultWithUsage(), nil + } + c.Writer.Flush() + } + } +} + +// writeAnthropicError writes an error response in Anthropic Messages API format. +func writeAnthropicError(c *gin.Context, statusCode int, errType, message string) { + c.JSON(statusCode, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go new file mode 100644 index 00000000..ada7d805 --- /dev/null +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -0,0 +1,946 @@ +package service + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +type openAIRecordUsageLogRepoStub struct { + UsageLogRepository + + inserted bool + err error + calls int + lastLog *UsageLog + lastCtxErr error +} + +func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { + s.calls++ + s.lastLog = log + s.lastCtxErr = ctx.Err() + return s.inserted, s.err +} + +type openAIRecordUsageBillingRepoStub struct { + 
UsageBillingRepository + + result *UsageBillingApplyResult + err error + calls int + lastCmd *UsageBillingCommand + lastCtxErr error +} + +func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) { + s.calls++ + s.lastCmd = cmd + s.lastCtxErr = ctx.Err() + if s.err != nil { + return nil, s.err + } + if s.result != nil { + return s.result, nil + } + return &UsageBillingApplyResult{Applied: true}, nil +} + +type openAIRecordUsageUserRepoStub struct { + UserRepository + + deductCalls int + deductErr error + lastAmount float64 + lastCtxErr error +} + +func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error { + s.deductCalls++ + s.lastAmount = amount + s.lastCtxErr = ctx.Err() + return s.deductErr +} + +type openAIRecordUsageSubRepoStub struct { + UserSubscriptionRepository + + incrementCalls int + incrementErr error + lastCtxErr error +} + +func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + s.incrementCalls++ + s.lastCtxErr = ctx.Err() + return s.incrementErr +} + +type openAIRecordUsageAPIKeyQuotaStub struct { + quotaCalls int + rateLimitCalls int + err error + lastAmount float64 + lastQuotaCtxErr error + lastRateLimitCtxErr error +} + +func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { + s.quotaCalls++ + s.lastAmount = cost + s.lastQuotaCtxErr = ctx.Err() + return s.err +} + +func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { + s.rateLimitCalls++ + s.lastAmount = cost + s.lastRateLimitCtxErr = ctx.Err() + return s.err +} + +type openAIUserGroupRateRepoStub struct { + UserGroupRateRepository + + rate *float64 + err error + calls int +} + +func (s *openAIUserGroupRateRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, 
error) { + s.calls++ + if s.err != nil { + return nil, s.err + } + return s.rate, nil +} + +func i64p(v int64) *int64 { + return &v +} + +func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 1.1 + svc := NewOpenAIGatewayService( + nil, + usageRepo, + nil, + userRepo, + subRepo, + rateRepo, + nil, + cfg, + nil, + nil, + NewBillingService(cfg, nil), + nil, + &BillingCacheService{}, + nil, + &DeferredService{}, + nil, + ) + svc.userGroupRateResolver = newUserGroupRateResolver( + rateRepo, + nil, + resolveUserGroupRateCacheTTL(cfg), + nil, + "service.openai_gateway.test", + ) + return svc +} + +func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService { + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo) + svc.usageBillingRepo = billingRepo + return svc +} + +func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown { + t.Helper() + + cost, err := svc.billingService.CalculateCost(model, UsageTokens{ + InputTokens: max(usage.InputTokens-usage.CacheReadInputTokens, 0), + OutputTokens: usage.OutputTokens, + CacheCreationTokens: usage.CacheCreationInputTokens, + CacheReadTokens: usage.CacheReadInputTokens, + }, multiplier) + require.NoError(t, err) + return cost +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) { + groupID := int64(11) + groupRate := 1.4 + userRate := 1.8 + usage := OpenAIUsage{InputTokens: 15, OutputTokens: 4, CacheReadInputTokens: 3} + + usageRepo := 
&openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + rateRepo := &openAIUserGroupRateRepoStub{rate: &userRate} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_user_group_rate", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1001, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: groupRate, + }, + }, + User: &User{ID: 2001}, + Account: &Account{ID: 3001}, + }) + + require.NoError(t, err) + require.Equal(t, 1, rateRepo.calls) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, userRate, usageRepo.lastLog.RateMultiplier) + require.Equal(t, 12, usageRepo.lastLog.InputTokens) + require.Equal(t, 3, usageRepo.lastLog.CacheReadTokens) + + expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, userRate) + require.InDelta(t, expected.ActualCost, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12) + require.Equal(t, 1, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_IncludesEndpointMetadata(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + rateRepo := &openAIUserGroupRateRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_endpoint_metadata", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 2, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1002, + Group: &Group{RateMultiplier: 1}, + }, + User: &User{ID: 2002}, + Account: &Account{ID: 3002}, + InboundEndpoint: 
" /v1/chat/completions ", + UpstreamEndpoint: " /v1/responses ", + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.NotNil(t, usageRepo.lastLog.InboundEndpoint) + require.Equal(t, "/v1/chat/completions", *usageRepo.lastLog.InboundEndpoint) + require.NotNil(t, usageRepo.lastLog.UpstreamEndpoint) + require.Equal(t, "/v1/responses", *usageRepo.lastLog.UpstreamEndpoint) +} + +func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) { + groupID := int64(12) + groupRate := 1.6 + usage := OpenAIUsage{InputTokens: 10, OutputTokens: 5, CacheReadInputTokens: 2} + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + rateRepo := &openAIUserGroupRateRepoStub{err: errors.New("db unavailable")} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_group_default_on_error", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1002, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: groupRate, + }, + }, + User: &User{ID: 2002}, + Account: &Account{ID: 3002}, + }) + + require.NoError(t, err) + require.Equal(t, 1, rateRepo.calls) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier) + + expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, groupRate) + require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12) +} + +func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolverMissing(t *testing.T) { + groupID := int64(13) + groupRate := 1.25 + usage := OpenAIUsage{InputTokens: 9, OutputTokens: 4, CacheReadInputTokens: 1} + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := 
&openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + svc.userGroupRateResolver = nil + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_group_default_nil_resolver", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1003, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: groupRate, + }, + }, + User: &User{ID: 2003}, + Account: &Account{ID: 3003}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier) +} + +func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_duplicate", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1004}, + User: &User{ID: 2004}, + Account: &Account{ID: 3004}, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}} + 
userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_duplicate_billing_key", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10045, + Quota: 100, + }, + User: &User{ID: 20045}, + Account: &Account{ID: 30045}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 0, quotaSvc.quotaCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) { + usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4} + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_usage_log_error", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10041}, + User: &User{ID: 20041}, + Account: &Account{ID: 30041}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: 
MarkUsageLogCreateNotPersisted(context.Canceled)} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_not_persisted", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10043, + Quota: 100, + }, + User: &User{ID: 20043}, + Account: &Account{ID: 30043}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 1, quotaSvc.quotaCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) { + usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2} + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_detached_billing_ctx", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10042, + Quota: 100, + }, + User: &User{ID: 20042}, + Account: &Account{ID: 30042}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func 
TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_detached_billing_repo_ctx", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10046}, + User: &User{ID: 20046}, + Account: &Account{ID: 30046}, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.NoError(t, billingRepo.lastCtxErr) + require.Equal(t, 1, usageRepo.calls) + require.NoError(t, usageRepo.lastCtxErr) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + + payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`)) + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "openai_payload_hash", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "gpt-5", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + RequestPayloadHash: payloadHash, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + 
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash) +} + +func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback") + err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10047}, + User: &User{ID: 20047}, + Account: &Account{ID: 30047}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID) +} + +func TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-stable-123") + err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "upstream-openai-volatile-456", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: 
&APIKey{ID: 10049}, + User: &User{ID: 20049}, + Account: &Account{ID: 30049}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "client:openai-client-stable-123", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "client:openai-client-stable-123", usageRepo.lastLog.RequestID) +} + +func TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10050}, + User: &User{ID: 20050}, + Account: &Account{ID: 30050}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:")) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_billing_fail", + Usage: OpenAIUsage{ + 
InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10048}, + User: &User{ID: 20048}, + Account: &Account{ID: 30048}, + }) + + require.Error(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 0, usageRepo.calls) +} + +func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) { + usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2} + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_quota_update", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1005, + Quota: 100, + }, + User: &User{ID: 2005}, + Account: &Account{ID: 3005}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.Equal(t, 0, quotaSvc.rateLimitCalls) + expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, 1.1) + require.InDelta(t, expected.ActualCost, quotaSvc.lastAmount, 1e-12) +} + +func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_clamp_actual_input", + Usage: OpenAIUsage{ + InputTokens: 2, + OutputTokens: 1, + CacheReadInputTokens: 5, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1006}, + User: 
&User{ID: 2006}, + Account: &Account{ID: 3006}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, 0, usageRepo.lastLog.InputTokens) +} + +func TestOpenAIGatewayServiceRecordUsage_Gpt54LongContextBillsWholeSession(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_gpt54_long_context", + Usage: OpenAIUsage{ + InputTokens: 300000, + OutputTokens: 2000, + }, + Model: "gpt-5.4-2026-03-05", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1014}, + User: &User{ID: 2014}, + Account: &Account{ID: 3014}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + + expectedInput := 300000 * 2.5e-6 * 2.0 + expectedOutput := 2000 * 15e-6 * 1.5 + require.InDelta(t, expectedInput, usageRepo.lastLog.InputCost, 1e-10) + require.InDelta(t, expectedOutput, usageRepo.lastLog.OutputCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, usageRepo.lastLog.TotalCost, 1e-10) + require.InDelta(t, (expectedInput+expectedOutput)*1.1, usageRepo.lastLog.ActualCost, 1e-10) + require.Equal(t, 1, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_ServiceTierPriorityUsesFastPricing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + serviceTier := "priority" + usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_service_tier_priority", + ServiceTier: &serviceTier, + Usage: 
usage, + Model: "gpt-5.4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1015}, + User: &User{ID: 2015}, + Account: &Account{ID: 3015}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.NotNil(t, usageRepo.lastLog.ServiceTier) + require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) + + baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 100, OutputTokens: 50}, 1.0) + require.NoError(t, calcErr) + require.InDelta(t, baseCost.TotalCost*2, usageRepo.lastLog.TotalCost, 1e-10) +} + +func TestOpenAIGatewayServiceRecordUsage_ServiceTierFlexHalvesCost(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + serviceTier := "flex" + usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50, CacheReadInputTokens: 20} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_service_tier_flex", + ServiceTier: &serviceTier, + Usage: usage, + Model: "gpt-5.4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1016}, + User: &User{ID: 2016}, + Account: &Account{ID: 3016}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + + baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 80, OutputTokens: 50, CacheReadTokens: 20}, 1.0) + require.NoError(t, calcErr) + require.InDelta(t, baseCost.TotalCost*0.5, usageRepo.lastLog.TotalCost, 1e-10) +} + +func TestNormalizeOpenAIServiceTier(t *testing.T) { + t.Run("fast maps to priority", func(t *testing.T) { + got := normalizeOpenAIServiceTier(" fast ") + require.NotNil(t, got) + require.Equal(t, "priority", *got) + }) + + t.Run("default ignored", func(t *testing.T) { + require.Nil(t, normalizeOpenAIServiceTier("default")) + }) + + t.Run("invalid 
ignored", func(t *testing.T) { + require.Nil(t, normalizeOpenAIServiceTier("turbo")) + }) +} + +func TestExtractOpenAIServiceTier(t *testing.T) { + require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"})) + require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"})) + require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1})) + require.Nil(t, extractOpenAIServiceTier(nil)) +} + +func TestExtractOpenAIServiceTierFromBody(t *testing.T) { + require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`))) + require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`))) + require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`))) + require.Nil(t, extractOpenAIServiceTierFromBody(nil)) +} + +func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + serviceTier := "priority" + reasoning := "high" + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_billing_model_override", + BillingModel: "gpt-5.1-codex", + Model: "gpt-5.1", + ServiceTier: &serviceTier, + ReasoningEffort: &reasoning, + Usage: OpenAIUsage{ + InputTokens: 20, + OutputTokens: 10, + }, + Duration: 2 * time.Second, + FirstTokenMs: func() *int { v := 120; return &v }(), + }, + APIKey: &APIKey{ID: 10, GroupID: i64p(11), Group: &Group{ID: 11, RateMultiplier: 1.2}}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + UserAgent: "codex-cli/1.0", + IPAddress: "127.0.0.1", + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model) 
+ require.NotNil(t, usageRepo.lastLog.ServiceTier) + require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) + require.NotNil(t, usageRepo.lastLog.ReasoningEffort) + require.Equal(t, reasoning, *usageRepo.lastLog.ReasoningEffort) + require.NotNil(t, usageRepo.lastLog.UserAgent) + require.Equal(t, "codex-cli/1.0", *usageRepo.lastLog.UserAgent) + require.NotNil(t, usageRepo.lastLog.IPAddress) + require.Equal(t, "127.0.0.1", *usageRepo.lastLog.IPAddress) + require.NotNil(t, usageRepo.lastLog.GroupID) + require.Equal(t, int64(11), *usageRepo.lastLog.GroupID) + require.Equal(t, 1, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + subscription := &UserSubscription{ID: 99} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_subscription_billing", + Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5}, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}}, + User: &User{ID: 200}, + Account: &Account{ID: 300}, + Subscription: subscription, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, BillingTypeSubscription, usageRepo.lastLog.BillingType) + require.NotNil(t, usageRepo.lastLog.SubscriptionID) + require.Equal(t, subscription.ID, *usageRepo.lastLog.SubscriptionID) + require.Equal(t, 1, subRepo.incrementCalls) + require.Equal(t, 0, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := 
&openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + svc.cfg.RunMode = config.RunModeSimple + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_simple_mode", + Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5}, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1000}, + User: &User{ID: 2000}, + Account: &Account{ID: 3000}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f624d92a..c8876edb 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -24,7 +24,9 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" + "github.com/google/uuid" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.uber.org/zap" @@ -49,6 +51,10 @@ const ( openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond openAIWSRetryBackoffMaxDefault = 2 * time.Second openAIWSRetryJitterRatioDefault = 0.2 + openAICompactSessionSeedKey = "openai_compact_session_seed" + codexCLIVersion = "0.104.0" + // Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。 + openAICodexSnapshotPersistMinInterval = 30 * time.Second ) // OpenAI allowed headers whitelist (for non-passthrough). @@ -204,12 +210,21 @@ type OpenAIUsage struct { type OpenAIForwardResult struct { RequestID string Usage OpenAIUsage - Model string + Model string // 原始模型(用于响应和日志显示) + // BillingModel is the model used for cost calculation. 
+ // When non-empty, CalculateCost uses this instead of Model. + // This is set by the Anthropic Messages conversion path where + // the mapped upstream model differs from the client-facing model. + BillingModel string + // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". + // Nil means the request did not specify a recognized tier. + ServiceTier *string // ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix. // Stored for usage records display; nil means not provided / not applicable. ReasoningEffort *string Stream bool OpenAIWSMode bool + ResponseHeaders http.Header Duration time.Duration FirstTokenMs *int } @@ -243,45 +258,92 @@ type openAIWSRetryMetrics struct { nonRetryableFastFallback atomic.Int64 } +type accountWriteThrottle struct { + minInterval time.Duration + mu sync.Mutex + lastByID map[int64]time.Time +} + +func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle { + return &accountWriteThrottle{ + minInterval: minInterval, + lastByID: make(map[int64]time.Time), + } +} + +func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool { + if t == nil || id <= 0 || t.minInterval <= 0 { + return true + } + + t.mu.Lock() + defer t.mu.Unlock() + + if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval { + return false + } + t.lastByID[id] = now + + if len(t.lastByID) > 4096 { + cutoff := now.Add(-4 * t.minInterval) + for accountID, writtenAt := range t.lastByID { + if writtenAt.Before(cutoff) { + delete(t.lastByID, accountID) + } + } + } + + return true +} + +var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval) + // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { - accountRepo AccountRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - cache GatewayCache - cfg *config.Config - 
codexDetector CodexClientRestrictionDetector - schedulerSnapshot *SchedulerSnapshotService - concurrencyService *ConcurrencyService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - httpUpstream HTTPUpstream - deferredService *DeferredService - openAITokenProvider *OpenAITokenProvider - toolCorrector *CodexToolCorrector - openaiWSResolver OpenAIWSProtocolResolver + accountRepo AccountRepository + usageLogRepo UsageLogRepository + usageBillingRepo UsageBillingRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + cache GatewayCache + cfg *config.Config + codexDetector CodexClientRestrictionDetector + schedulerSnapshot *SchedulerSnapshotService + concurrencyService *ConcurrencyService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + userGroupRateResolver *userGroupRateResolver + httpUpstream HTTPUpstream + deferredService *DeferredService + openAITokenProvider *OpenAITokenProvider + toolCorrector *CodexToolCorrector + openaiWSResolver OpenAIWSProtocolResolver - openaiWSPoolOnce sync.Once - openaiWSStateStoreOnce sync.Once - openaiSchedulerOnce sync.Once - openaiWSPool *openAIWSConnPool - openaiWSStateStore OpenAIWSStateStore - openaiScheduler OpenAIAccountScheduler - openaiAccountStats *openAIAccountRuntimeStats + openaiWSPoolOnce sync.Once + openaiWSStateStoreOnce sync.Once + openaiSchedulerOnce sync.Once + openaiWSPassthroughDialerOnce sync.Once + openaiWSPool *openAIWSConnPool + openaiWSStateStore OpenAIWSStateStore + openaiScheduler OpenAIAccountScheduler + openaiWSPassthroughDialer openAIWSClientDialer + openaiAccountStats *openAIAccountRuntimeStats openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time openaiWSRetryMetrics openAIWSRetryMetrics responseHeaderFilter *responseheaders.CompiledHeaderFilter + codexSnapshotThrottle *accountWriteThrottle } // NewOpenAIGatewayService creates a new 
OpenAIGatewayService func NewOpenAIGatewayService( accountRepo AccountRepository, usageLogRepo UsageLogRepository, + usageBillingRepo UsageBillingRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, cache GatewayCache, cfg *config.Config, schedulerSnapshot *SchedulerSnapshotService, @@ -294,29 +356,55 @@ func NewOpenAIGatewayService( openAITokenProvider *OpenAITokenProvider, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ - accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, - codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - httpUpstream: httpUpstream, - deferredService: deferredService, - openAITokenProvider: openAITokenProvider, - toolCorrector: NewCodexToolCorrector(), - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - responseHeaderFilter: compileResponseHeaderFilter(cfg), + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + cache: cache, + cfg: cfg, + codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + userGroupRateResolver: newUserGroupRateResolver( + userGroupRateRepo, + nil, + resolveUserGroupRateCacheTTL(cfg), + nil, + "service.openai_gateway", + ), + httpUpstream: httpUpstream, + deferredService: deferredService, + openAITokenProvider: openAITokenProvider, + toolCorrector: NewCodexToolCorrector(), + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + responseHeaderFilter: 
compileResponseHeaderFilter(cfg), + codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } svc.logOpenAIWSModeBootstrap() return svc } +func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { + if s != nil && s.codexSnapshotThrottle != nil { + return s.codexSnapshotThrottle + } + return defaultOpenAICodexSnapshotPersistThrottle +} + +func (s *OpenAIGatewayService) billingDeps() *billingDeps { + return &billingDeps{ + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + } +} + // CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。 // 应在应用优雅关闭时调用。 func (s *OpenAIGatewayService) CloseOpenAIWSPool() { @@ -393,6 +481,7 @@ func classifyOpenAIWSReconnectReason(err error) (string, bool) { "upgrade_required", "ws_unsupported", "auth_failed", + "invalid_encrypted_content", "previous_response_not_found": return reason, false } @@ -443,6 +532,14 @@ func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType st } switch reason { + case "invalid_encrypted_content": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + errType = "invalid_request_error" + if upstreamMessage == "" { + upstreamMessage = "encrypted content could not be verified" + } case "previous_response_not_found": if statusCode == 0 { statusCode = http.StatusBadRequest @@ -691,6 +788,20 @@ func getAPIKeyIDFromContext(c *gin.Context) int64 { return apiKey.ID } +// isolateOpenAISessionID 将 apiKeyID 混入 session 标识符, +// 确保不同 API Key 的用户即使使用相同的原始 session_id/conversation_id, +// 到达上游的标识符也不同,防止跨用户会话碰撞。 +func isolateOpenAISessionID(apiKeyID int64, raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + h := xxhash.New() + _, _ = fmt.Fprintf(h, "k%d:", apiKeyID) + _, _ = h.WriteString(raw) + return fmt.Sprintf("%016x", h.Sum64()) +} + func logCodexCLIOnlyDetection(ctx 
context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) { if !result.Enabled { return @@ -804,8 +915,10 @@ func logOpenAIInstructionsRequiredDebug( } userAgent := "" + originator := "" if c != nil { userAgent = strings.TrimSpace(c.GetHeader("User-Agent")) + originator = strings.TrimSpace(c.GetHeader("originator")) } fields := []zap.Field{ @@ -815,7 +928,7 @@ func logOpenAIInstructionsRequiredDebug( zap.Int("upstream_status_code", upstreamStatusCode), zap.String("upstream_error_message", msg), zap.String("request_user_agent", userAgent), - zap.Bool("codex_official_client_match", openai.IsCodexCLIRequest(userAgent)), + zap.Bool("codex_official_client_match", openai.IsCodexOfficialClientByHeaders(userAgent, originator)), } fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody) @@ -876,6 +989,52 @@ func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg strin return false } +func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool { + if upstreamStatusCode != http.StatusBadRequest { + return false + } + + match := func(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if strings.Contains(lower, "an error occurred while processing your request") { + return true + } + return strings.Contains(lower, "you can retry your request") && + strings.Contains(lower, "help.openai.com") && + strings.Contains(lower, "request id") + } + + if match(upstreamMsg) { + return true + } + if len(upstreamBody) == 0 { + return false + } + if match(gjson.GetBytes(upstreamBody, "error.message").String()) { + return true + } + return match(string(upstreamBody)) +} + +// ExtractSessionID extracts the raw session ID from headers or body without hashing. +// Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache. 
+func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string { + if c == nil { + return "" + } + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + } + return sessionID +} + // GenerateSessionHash generates a sticky-session hash for OpenAI requests. // // Priority: @@ -922,6 +1081,18 @@ func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, b return currentHash } +func resolveOpenAIUpstreamOriginator(c *gin.Context, isOfficialClient bool) string { + if c != nil { + if originator := strings.TrimSpace(c.GetHeader("originator")); originator != "" { + return originator + } + } + if isOfficialClient { + return "codex_cli_rs" + } + return "opencode" +} + // BindStickySession sets session -> account binding with standard TTL. func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { if sessionHash == "" || accountID <= 0 { @@ -966,7 +1137,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 3. 按优先级 + LRU 选择最佳账号 // Select by priority + LRU - selected := s.selectBestAccount(accounts, requestedModel, excludedIDs) + selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs) if selected == nil { if requestedModel != "" { @@ -1039,7 +1210,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // // selectBestAccount selects the best account from candidates (priority + LRU). // Returns nil if no available account. 
-func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { +func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { var selected *Account for i := range accounts { @@ -1051,27 +1222,20 @@ func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedMo continue } - // 调度器快照可能暂时过时,这里重新检查可调度性和平台 - // Scheduler snapshots can be temporarily stale; re-check schedulability and platform - if !acc.IsSchedulable() || !acc.IsOpenAI() { - continue - } - - // 检查模型支持 - // Check model support - if requestedModel != "" && !acc.IsModelSupported(requestedModel) { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { continue } // 选择优先级最高且最久未使用的账号 // Select highest priority and least recently used if selected == nil { - selected = acc + selected = fresh continue } - if s.isBetterAccount(acc, selected) { - selected = acc + if s.isBetterAccount(fresh, selected) { + selected = fresh } } @@ -1163,7 +1327,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex return nil, err } if len(accounts) == 0 { - return nil, errors.New("no available accounts") + return nil, ErrNoAvailableAccounts } isExcluded := func(accountID int64) bool { @@ -1233,14 +1397,14 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } if len(candidates) == 0 { - return nil, errors.New("no available accounts") + return nil, ErrNoAvailableAccounts } accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) for _, acc := range candidates { accountLoads = append(accountLoads, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: acc.Concurrency, + MaxConcurrency: acc.EffectiveLoadFactor(), }) } @@ -1249,13 +1413,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ordered := 
append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, false) for _, acc := range ordered { - result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { + continue + } + result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, acc.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ - Account: acc, + Account: fresh, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil @@ -1299,13 +1467,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex shuffleWithinSortGroups(available) for _, item := range available { - result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel) + if fresh == nil { + continue + } + result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, item.account.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ - Account: item.account, + Account: fresh, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil @@ -1317,18 +1489,22 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex // ============ Layer 3: Fallback wait ============ sortAccountsByPriorityAndLastUsed(candidates, false) for _, acc := range candidates { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { + continue + } return &AccountSelectionResult{ - Account: acc, + 
Account: fresh, WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, Timeout: cfg.FallbackWaitTimeout, MaxWaiting: cfg.FallbackMaxWaiting, }, }, nil } - return nil, errors.New("no available accounts") + return nil, ErrNoAvailableAccounts } func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) { @@ -1343,7 +1519,7 @@ func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, grou } else if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, PlatformOpenAI) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) @@ -1358,11 +1534,44 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } -func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { - if s.schedulerSnapshot != nil { - return s.schedulerSnapshot.GetAccount(ctx, accountID) +func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string) *Account { + if account == nil { + return nil } - return s.accountRepo.GetByID(ctx, accountID) + + fresh := account + if s.schedulerSnapshot != nil { + current, err := s.getSchedulableAccount(ctx, account.ID) + if err != nil || current == nil { + return nil + } + fresh = current + } + + if !fresh.IsSchedulable() || !fresh.IsOpenAI() { + return nil + } + if requestedModel != "" && !fresh.IsModelSupported(requestedModel) { + return nil + } + return fresh +} + +func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) 
(*Account, error) { + var ( + account *Account + err error + ) + if s.schedulerSnapshot != nil { + account, err = s.schedulerSnapshot.GetAccount(ctx, accountID) + } else { + account, err = s.accountRepo.GetByID(ctx, accountID) + } + if err != nil || account == nil { + return account, err + } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, time.Now()) + return account, nil } func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { @@ -1417,6 +1626,13 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool } } +func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode int, upstreamMsg string, upstreamBody []byte) bool { + if s.shouldFailoverUpstreamError(statusCode) { + return true + } + return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody) +} + func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) @@ -1443,7 +1659,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) originalModel := reqModel - isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) clientTransport := GetOpenAIClientTransport(c) // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 @@ -1551,13 +1767,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco patchDisabled = true } - // 非透传模式下,保持历史行为:非 Codex CLI 请求在 instructions 为空时注入默认指令。 - if !isCodexCLI && 
isInstructionsEmpty(reqBody) { - if instructions := strings.TrimSpace(GetOpenCodeInstructions()); instructions != "" { - reqBody["instructions"] = instructions - bodyModified = true - markPatchSet("instructions", instructions) - } + // 非透传模式下,instructions 为空时注入默认指令。 + if isInstructionsEmpty(reqBody) { + reqBody["instructions"] = "You are a helpful coding assistant." + bodyModified = true + markPatchSet("instructions", "You are a helpful coding assistant.") } // 对所有请求执行模型映射(包含 Codex CLI)。 @@ -1580,6 +1794,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified = true markPatchSet("model", normalizedModel) } + + // 移除 gpt-5.2-codex 以下的版本 verbosity 参数 + // 确保高版本模型向低版本模型映射不报错 + if !SupportsVerbosity(normalizedModel) { + if text, ok := reqBody["text"].(map[string]any); ok { + delete(text, "verbosity") + } + } } // 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。 @@ -1593,7 +1815,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } if account.Type == AccountTypeOAuth { - codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI) + codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isOpenAIResponsesCompactPath(c)) if codexResult.Modified { bodyModified = true disablePatch() @@ -1726,6 +1948,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco var wsErr error wsLastFailureReason := "" wsPrevResponseRecoveryTried := false + wsInvalidEncryptedContentRecoveryTried := false recoverPrevResponseNotFound := func(attempt int) bool { if wsPrevResponseRecoveryTried { return false @@ -1758,6 +1981,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ) return true } + recoverInvalidEncryptedContent := func(attempt int) bool { + if wsInvalidEncryptedContentRecoveryTried { + return false + } + removedReasoningItems := trimOpenAIEncryptedReasoningItems(wsReqBody) + if !removedReasoningItems { + logOpenAIWSModeInfo( + 
"reconnect_invalid_encrypted_content_recovery_skip account_id=%d attempt=%d reason=missing_encrypted_reasoning_items", + account.ID, + attempt, + ) + return false + } + previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id") + hasFunctionCallOutput := HasFunctionCallOutput(wsReqBody) + if previousResponseID != "" && !hasFunctionCallOutput { + delete(wsReqBody, "previous_response_id") + } + wsInvalidEncryptedContentRecoveryTried = true + logOpenAIWSModeInfo( + "reconnect_invalid_encrypted_content_recovery account_id=%d attempt=%d action=drop_encrypted_reasoning_items retry=1 previous_response_id_present=%v previous_response_id=%s previous_response_id_kind=%s has_function_call_output=%v dropped_previous_response_id=%v", + account.ID, + attempt, + previousResponseID != "", + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)), + hasFunctionCallOutput, + previousResponseID != "" && !hasFunctionCallOutput, + ) + return true + } retryBudget := s.openAIWSRetryTotalBudget() retryStartedAt := time.Now() wsRetryLoop: @@ -1794,6 +2048,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) { continue } + if reason == "invalid_encrypted_content" && recoverInvalidEncryptedContent(attempt) { + continue + } if retryable && attempt < maxAttempts { backoff := s.openAIWSRetryBackoff(attempt) if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget { @@ -1877,118 +2134,143 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco return nil, wsErr } - // Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) - if err != nil { - return nil, err - } + httpInvalidEncryptedContentRetryTried := false + for { + // Build upstream 
request + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) + releaseUpstreamCtx() + if err != nil { + return nil, err + } - // Get proxy URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } + // Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } - // Send request - upstreamStart := time.Now() - resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) - if err != nil { - // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - setOpsUpstreamError(c, 0, safeErr, "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - Kind: "request_error", - Message: safeErr, - }) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream request failed", - }, - }) - return nil, fmt.Errorf("upstream request failed: %s", safeErr) - } - defer func() { _ = resp.Body.Close() }() + // Send request + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). 
+ safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } - // Handle error response - if resp.StatusCode >= 400 { - if s.shouldFailoverUpstreamError(resp.StatusCode) { + // Handle error response + if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 + upstreamCode := extractUpstreamErrorCode(respBody) + if !httpInvalidEncryptedContentRetryTried && resp.StatusCode == http.StatusBadRequest && upstreamCode == "invalid_encrypted_content" { + if trimOpenAIEncryptedReasoningItems(reqBody) { + body, err = json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("serialize invalid_encrypted_content retry body: %w", err) + } + setOpsUpstreamRequestBody(c, body) + httpInvalidEncryptedContentRetryTried = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Retrying non-WSv2 request once after invalid_encrypted_content (account: %s)", account.Name) + continue } - upstreamDetail = truncateString(string(respBody), maxBytes) + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Skip non-WSv2 invalid_encrypted_content retry because encrypted reasoning items are missing (account: %s)", account.Name) } - 
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) - s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } + } + return s.handleErrorResponse(ctx, resp, c, account, body) } - return s.handleErrorResponse(ctx, resp, c, account, body) - } + defer func() { _ = resp.Body.Close() }() - // Handle normal response - var usage *OpenAIUsage - var firstTokenMs *int - if reqStream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel) - if err != nil { - return nil, err + // Handle normal response + var usage *OpenAIUsage + var firstTokenMs *int + if reqStream { + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, 
originalModel, mappedModel) + if err != nil { + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + } else { + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel) + if err != nil { + return nil, err + } } - usage = streamResult.usage - firstTokenMs = streamResult.firstTokenMs - } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel) - if err != nil { - return nil, err + + // Extract and save Codex usage snapshot from response headers (for OAuth accounts) + if account.Type == AccountTypeOAuth { + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } } - } - // Extract and save Codex usage snapshot from response headers (for OAuth accounts) - if account.Type == AccountTypeOAuth { - if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { - s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + if usage == nil { + usage = &OpenAIUsage{} } + + reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) + serviceTier := extractOpenAIServiceTier(reqBody) + + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, + ServiceTier: serviceTier, + ReasoningEffort: reasoningEffort, + Stream: reqStream, + OpenAIWSMode: false, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil } - - if usage == nil { - usage = &OpenAIUsage{} - } - - reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) - - return &OpenAIForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: *usage, - Model: originalModel, - ReasoningEffort: reasoningEffort, - Stream: reqStream, - OpenAIWSMode: false, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - }, nil } func (s *OpenAIGatewayService) forwardOpenAIPassthrough( @@ -2025,14 +2307,14 @@ func (s 
*OpenAIGatewayService) forwardOpenAIPassthrough( return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason) } - normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body) + normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body, isOpenAIResponsesCompactPath(c)) if err != nil { return nil, err } if normalized { body = normalizedBody - reqStream = true } + reqStream = gjson.GetBytes(body, "stream").Bool() } logger.LegacyPrintf("service.openai_gateway", @@ -2064,7 +2346,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( return nil, err } - upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -2137,6 +2421,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: reqModel, + ServiceTier: extractOpenAIServiceTierFromBody(body), ReasoningEffort: reasoningEffort, Stream: reqStream, OpenAIWSMode: false, @@ -2197,6 +2482,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( targetURL = buildOpenAIResponsesURL(validatedURL) } } + targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) if err != nil { @@ -2230,7 +2516,19 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { req.Header.Set("chatgpt-account-id", chatgptAccountID) } - if req.Header.Get("accept") == "" { + apiKeyID := getAPIKeyIDFromContext(c) + // 先保存客户端原始值,再做 compact 补充,避免后续统一隔离时读到已处理的值。 + clientSessionID := 
strings.TrimSpace(req.Header.Get("session_id")) + clientConversationID := strings.TrimSpace(req.Header.Get("conversation_id")) + if isOpenAIResponsesCompactPath(c) { + req.Header.Set("accept", "application/json") + if req.Header.Get("version") == "" { + req.Header.Set("version", codexCLIVersion) + } + if clientSessionID == "" { + clientSessionID = resolveOpenAICompactSessionID(c) + } + } else if req.Header.Get("accept") == "" { req.Header.Set("accept", "text/event-stream") } if req.Header.Get("OpenAI-Beta") == "" { @@ -2239,13 +2537,18 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( if req.Header.Get("originator") == "" { req.Header.Set("originator", "codex_cli_rs") } - if promptCacheKey != "" { - if req.Header.Get("conversation_id") == "" { - req.Header.Set("conversation_id", promptCacheKey) - } - if req.Header.Get("session_id") == "" { - req.Header.Set("session_id", promptCacheKey) - } + // 用隔离后的 session 标识符覆盖客户端透传值,防止跨用户会话碰撞。 + if clientSessionID == "" { + clientSessionID = promptCacheKey + } + if clientConversationID == "" { + clientConversationID = promptCacheKey + } + if clientSessionID != "" { + req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, clientSessionID)) + } + if clientConversationID != "" { + req.Header.Set("conversation_id", isolateOpenAISessionID(apiKeyID, clientConversationID)) } } @@ -2391,6 +2694,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( var firstTokenMs *int clientDisconnected := false sawDone := false + sawTerminalEvent := false upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) scanner := bufio.NewScanner(resp.Body) @@ -2410,6 +2714,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( if trimmedData == "[DONE]" { sawDone = true } + if openAIStreamEventIsTerminal(trimmedData) { + sawTerminalEvent = true + } if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) 
firstTokenMs = &ms @@ -2427,19 +2734,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( } } if err := scanner.Err(); err != nil { - if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err) + if sawTerminalEvent { return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil } + if clientDisconnected { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) + } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.openai_gateway", - "[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v", - account.ID, - upstreamRequestID, - err, - ctx.Err(), - ) - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err) } if errors.Is(err, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) @@ -2453,12 +2755,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) } - if !clientDisconnected && !sawDone && ctx.Err() == nil { + if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil { logger.FromContext(ctx).With( zap.String("component", "service.openai_gateway"), zap.Int64("account_id", account.ID), zap.String("upstream_request_id", upstreamRequestID), ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream 
usage incomplete: missing terminal event") } return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil @@ -2577,6 +2880,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. default: targetURL = openaiPlatformAPIURL } + targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c)) req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { @@ -2607,16 +2911,27 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. } } if account.Type == AccountTypeOAuth { + // 清除客户端透传的 session 头,后续用隔离后的值重新设置,防止跨用户会话碰撞。 + req.Header.Del("conversation_id") + req.Header.Del("session_id") + req.Header.Set("OpenAI-Beta", "responses=experimental") - if isCodexCLI { - req.Header.Set("originator", "codex_cli_rs") + req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) + apiKeyID := getAPIKeyIDFromContext(c) + if isOpenAIResponsesCompactPath(c) { + req.Header.Set("accept", "application/json") + if req.Header.Get("version") == "" { + req.Header.Set("version", codexCLIVersion) + } + compactSession := resolveOpenAICompactSessionID(c) + req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, compactSession)) } else { - req.Header.Set("originator", "opencode") + req.Header.Set("accept", "text/event-stream") } - req.Header.Set("accept", "text/event-stream") if promptCacheKey != "" { - req.Header.Set("conversation_id", promptCacheKey) - req.Header.Set("session_id", promptCacheKey) + isolated := isolateOpenAISessionID(apiKeyID, promptCacheKey) + req.Header.Set("conversation_id", isolated) + req.Header.Set("session_id", isolated) } } @@ -2741,7 +3056,11 @@ func (s *OpenAIGatewayService) handleErrorResponse( Detail: upstreamDetail, }) if shouldDisable { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, 
+ ResponseBody: body, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } } // Return appropriate error response @@ -2784,6 +3103,120 @@ func (s *OpenAIGatewayService) handleErrorResponse( return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) } +// compatErrorWriter is the signature for format-specific error writers used by +// the compat paths (Chat Completions and Anthropic Messages). +type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string) + +// handleCompatErrorResponse is the shared non-failover error handler for the +// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of +// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit +// tracking, secondary failover) but delegates the final error write to the +// format-specific writer function. +func (s *OpenAIGatewayService) handleCompatErrorResponse( + resp *http.Response, + c *gin.Context, + account *Account, + writeError compatErrorWriter, +) (*OpenAIForwardResult, error) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + if upstreamMsg == "" { + upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode) + } + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + // Apply error passthrough rules + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, account.Platform, resp.StatusCode, body, + http.StatusBadGateway, "api_error", "Upstream request failed", + ); matched { + writeError(c, status, errType, errMsg) + if upstreamMsg == "" { + 
upstreamMsg = errMsg + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + + // Check custom error codes — if the account does not handle this status, + // return a generic error without exposing upstream details. + if !account.ShouldHandleErrorCode(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error") + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg) + } + + // Track rate limits and decide whether to trigger secondary failover. 
+ shouldDisable := false + if s.rateLimitService != nil { + shouldDisable = s.rateLimitService.HandleUpstreamError( + c.Request.Context(), account, resp.StatusCode, resp.Header, body, + ) + } + kind := "http_error" + if shouldDisable { + kind = "failover" + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: kind, + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if shouldDisable { + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: body, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + + // Map status code to error type and write response + errType := "api_error" + switch { + case resp.StatusCode == 400: + errType = "invalid_request_error" + case resp.StatusCode == 404: + errType = "not_found_error" + case resp.StatusCode == 429: + errType = "rate_limit_error" + case resp.StatusCode >= 500: + errType = "api_error" + } + + writeError(c, resp.StatusCode, errType, upstreamMsg) + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) +} + // openaiStreamingResult streaming response result type openaiStreamingResult struct { usage *OpenAIUsage @@ -2867,6 +3300,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp // 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。 errorEventSent := false clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage + sawTerminalEvent := false sendErrorEvent := func(reason string) { if errorEventSent || clientDisconnected { return @@ -2897,22 +3331,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage") } } + if !sawTerminalEvent { + return resultWithUsage(), fmt.Errorf("stream 
usage incomplete: missing terminal event") + } return resultWithUsage(), nil } handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { if scanErr == nil { return nil, nil, false } + if sawTerminalEvent { + logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr) + return resultWithUsage(), nil, true + } // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { - logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") - return resultWithUsage(), nil, true + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true } // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr) - return resultWithUsage(), nil, true + return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true } if errors.Is(scanErr, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) @@ -2935,6 +3374,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } dataBytes := []byte(data) + if openAIStreamEventIsTerminal(data) { + sawTerminalEvent = true + } // Correct Codex tool calls if needed (apply_patch -> edit, etc.) 
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { @@ -3051,8 +3493,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp continue } if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage") - return resultWithUsage(), nil + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") } logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 @@ -3150,11 +3591,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { return } - // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 - if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) { + // 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。 + if len(data) < 72 { return } - if gjson.GetBytes(data, "type").String() != "response.completed" { + eventType := gjson.GetBytes(data, "type").String() + if eventType != "response.completed" && eventType != "response.done" { return } @@ -3249,6 +3691,14 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. 
// Correct tool calls in final response body = s.correctToolCallsInResponseBody(body) } else { + terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText) + if terminalOK && terminalType == "response.failed" { + msg := extractOpenAISSEErrorMessage(terminalPayload) + if msg == "" { + msg = "Upstream compact response failed" + } + return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg) + } usage = s.parseSSEUsageFromBody(bodyText) if originalModel != mappedModel { bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel) @@ -3270,6 +3720,51 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. return usage, nil } +func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) { + lines := strings.Split(body, "\n") + for _, line := range lines { + data, ok := extractOpenAISSEDataLine(line) + if !ok || data == "" || data == "[DONE]" { + continue + } + eventType := strings.TrimSpace(gjson.Get(data, "type").String()) + switch eventType { + case "response.completed", "response.done", "response.failed": + return eventType, []byte(data), true + } + } + return "", nil, false +} + +func extractOpenAISSEErrorMessage(payload []byte) string { + if len(payload) == 0 { + return "" + } + for _, path := range []string{"response.error.message", "error.message", "message"} { + if msg := strings.TrimSpace(gjson.GetBytes(payload, path).String()); msg != "" { + return sanitizeUpstreamErrorMessage(msg) + } + } + return sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(payload))) +} + +func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.Response, c *gin.Context, message string) error { + message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message)) + if message == "" { + message = "Upstream returned an invalid non-streaming response" + } + setOpsUpstreamError(c, http.StatusBadGateway, message, "") + 
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": message, + }, + }) + return fmt.Errorf("non-streaming openai protocol error: %s", message) +} + func extractCodexFinalResponse(body string) ([]byte, bool) { lines := strings.Split(body, "\n") for _, line := range lines { @@ -3351,6 +3846,198 @@ func buildOpenAIResponsesURL(base string) string { return normalized + "/v1/responses" } +func trimOpenAIEncryptedReasoningItems(reqBody map[string]any) bool { + if len(reqBody) == 0 { + return false + } + + inputValue, has := reqBody["input"] + if !has { + return false + } + + switch input := inputValue.(type) { + case []any: + filtered := input[:0] + changed := false + for _, item := range input { + nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item) + if itemChanged { + changed = true + } + if !keep { + continue + } + filtered = append(filtered, nextItem) + } + if !changed { + return false + } + if len(filtered) == 0 { + delete(reqBody, "input") + return true + } + reqBody["input"] = filtered + return true + case []map[string]any: + filtered := input[:0] + changed := false + for _, item := range input { + nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item) + if itemChanged { + changed = true + } + if !keep { + continue + } + nextMap, ok := nextItem.(map[string]any) + if !ok { + filtered = append(filtered, item) + continue + } + filtered = append(filtered, nextMap) + } + if !changed { + return false + } + if len(filtered) == 0 { + delete(reqBody, "input") + return true + } + reqBody["input"] = filtered + return true + case map[string]any: + nextItem, changed, keep := sanitizeEncryptedReasoningInputItem(input) + if !changed { + return false + } + if !keep { + delete(reqBody, "input") + return true + } + nextMap, ok := 
nextItem.(map[string]any) + if !ok { + return false + } + reqBody["input"] = nextMap + return true + default: + return false + } +} + +func sanitizeEncryptedReasoningInputItem(item any) (next any, changed bool, keep bool) { + inputItem, ok := item.(map[string]any) + if !ok { + return item, false, true + } + + itemType, _ := inputItem["type"].(string) + if strings.TrimSpace(itemType) != "reasoning" { + return item, false, true + } + + _, hasEncryptedContent := inputItem["encrypted_content"] + if !hasEncryptedContent { + return item, false, true + } + + delete(inputItem, "encrypted_content") + if len(inputItem) == 1 { + return nil, true, false + } + return inputItem, true, true +} + +func IsOpenAIResponsesCompactPathForTest(c *gin.Context) bool { + return isOpenAIResponsesCompactPath(c) +} + +func OpenAICompactSessionSeedKeyForTest() string { + return openAICompactSessionSeedKey +} + +func NormalizeOpenAICompactRequestBodyForTest(body []byte) ([]byte, bool, error) { + return normalizeOpenAICompactRequestBody(body) +} + +func isOpenAIResponsesCompactPath(c *gin.Context) bool { + suffix := strings.TrimSpace(openAIResponsesRequestPathSuffix(c)) + return suffix == "/compact" || strings.HasPrefix(suffix, "/compact/") +} + +func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) { + if len(body) == 0 { + return body, false, nil + } + + normalized := []byte(`{}`) + for _, field := range []string{"model", "input", "instructions", "previous_response_id"} { + value := gjson.GetBytes(body, field) + if !value.Exists() { + continue + } + next, err := sjson.SetRawBytes(normalized, field, []byte(value.Raw)) + if err != nil { + return body, false, fmt.Errorf("normalize compact body %s: %w", field, err) + } + normalized = next + } + + if bytes.Equal(bytes.TrimSpace(body), bytes.TrimSpace(normalized)) { + return body, false, nil + } + return normalized, true, nil +} + +func resolveOpenAICompactSessionID(c *gin.Context) string { + if c != nil { + if sessionID := 
strings.TrimSpace(c.GetHeader("session_id")); sessionID != "" { + return sessionID + } + if conversationID := strings.TrimSpace(c.GetHeader("conversation_id")); conversationID != "" { + return conversationID + } + if seed, ok := c.Get(openAICompactSessionSeedKey); ok { + if seedStr, ok := seed.(string); ok && strings.TrimSpace(seedStr) != "" { + return strings.TrimSpace(seedStr) + } + } + } + return uuid.NewString() +} + +func openAIResponsesRequestPathSuffix(c *gin.Context) string { + if c == nil || c.Request == nil || c.Request.URL == nil { + return "" + } + normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/") + if normalizedPath == "" { + return "" + } + idx := strings.LastIndex(normalizedPath, "/responses") + if idx < 0 { + return "" + } + suffix := normalizedPath[idx+len("/responses"):] + if suffix == "" || suffix == "/" { + return "" + } + if !strings.HasPrefix(suffix, "/") { + return "" + } + return suffix +} + +func appendOpenAIResponsesRequestPathSuffix(baseURL, suffix string) string { + trimmedBase := strings.TrimRight(strings.TrimSpace(baseURL), "/") + trimmedSuffix := strings.TrimSpace(suffix) + if trimmedBase == "" || trimmedSuffix == "" { + return trimmedBase + } + return trimmedBase + trimmedSuffix +} + func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { // 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化 if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { @@ -3365,19 +4052,29 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { - Result *OpenAIForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - APIKeyService APIKeyQuotaUpdater + Result *OpenAIForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription 
*UserSubscription + InboundEndpoint string + UpstreamEndpoint string + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string + APIKeyService APIKeyQuotaUpdater } // RecordUsage records usage and deducts balance func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { result := input.Result + + // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 + if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && + result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 { + return nil + } + apiKey := input.APIKey user := input.User account := input.Account @@ -3401,10 +4098,22 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Get rate multiplier multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier + resolver := s.userGroupRateResolver + if resolver == nil { + resolver = newUserGroupRateResolver(nil, nil, resolveUserGroupRateCacheTTL(s.cfg), nil, "service.openai_gateway") + } + multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) } - cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier) + billingModel := result.Model + if result.BillingModel != "" { + billingModel = result.BillingModel + } + serviceTier := "" + if result.ServiceTier != nil { + serviceTier = strings.TrimSpace(*result.ServiceTier) + } + cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) if err != nil { cost = &CostBreakdown{ActualCost: 0} } @@ -3419,13 +4128,17 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Create usage log durationMs := int(result.Duration.Milliseconds()) accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ 
UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, - Model: result.Model, + RequestID: requestID, + Model: billingModel, + ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort, + InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), + UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), InputTokens: actualInputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, @@ -3445,7 +4158,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec FirstTokenMs: result.FirstTokenMs, CreatedAt: time.Now(), } - // 添加 UserAgent if input.UserAgent != "" { usageLog.UserAgent = &input.UserAgent @@ -3463,37 +4175,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, usageLog) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps(), s.usageBillingRepo) + return err + }() - // Deduct based on billing type - if isSubscriptionBilling { - if shouldBill && cost.TotalCost > 0 { - _ = s.userSubRepo.IncrementUsage(ctx, 
subscription.ID, cost.TotalCost) - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } - } else { - if shouldBill && cost.ActualCost > 0 { - _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - } + if billingErr != nil { + return billingErr } - - // Update API key quota if applicable (only for balance mode with quota set) - if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err) - } - } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") return nil } @@ -3655,6 +4362,69 @@ func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow return updates } +func codexUsagePercentExhausted(value *float64) bool { + return value != nil && *value >= 100-1e-9 +} + +func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) *time.Time { + if snapshot == nil { + return nil + } + normalized := snapshot.Normalize() + if normalized == nil { + return nil + } + baseTime := codexSnapshotBaseTime(snapshot, fallbackNow) + if codexUsagePercentExhausted(normalized.Used7dPercent) && normalized.Reset7dSeconds != nil { + resetAt := baseTime.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second) + return &resetAt + } + if codexUsagePercentExhausted(normalized.Used5hPercent) && normalized.Reset5hSeconds != nil { + resetAt := baseTime.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second) + return &resetAt + } + return nil +} + +func codexRateLimitResetAtFromExtra(extra map[string]any, now time.Time) *time.Time { + if len(extra) == 0 { + return nil + } + 
if progress := buildCodexUsageProgressFromExtra(extra, "7d", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) { + resetAt := progress.ResetsAt.UTC() + return &resetAt + } + if progress := buildCodexUsageProgressFromExtra(extra, "5h", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) { + resetAt := progress.ResetsAt.UTC() + return &resetAt + } + return nil +} + +func applyOpenAICodexRateLimitFromExtra(account *Account, now time.Time) (*time.Time, bool) { + if account == nil || !account.IsOpenAI() { + return nil, false + } + resetAt := codexRateLimitResetAtFromExtra(account.Extra, now) + if resetAt == nil { + return nil, false + } + if account.RateLimitResetAt != nil && now.Before(*account.RateLimitResetAt) && !account.RateLimitResetAt.Before(*resetAt) { + return account.RateLimitResetAt, false + } + account.RateLimitResetAt = resetAt + return resetAt, true +} + +func syncOpenAICodexRateLimitFromExtra(ctx context.Context, repo AccountRepository, account *Account, now time.Time) *time.Time { + resetAt, changed := applyOpenAICodexRateLimitFromExtra(account, now) + if !changed || resetAt == nil || repo == nil || account == nil || account.ID <= 0 { + return resetAt + } + _ = repo.SetRateLimited(ctx, account.ID, *resetAt) + return resetAt +} + // updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { if snapshot == nil { @@ -3664,19 +4434,38 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc return } - updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) - if len(updates) == 0 { + now := time.Now() + updates := buildCodexUsageExtraUpdates(snapshot, now) + resetAt := 
codexRateLimitResetAtFromSnapshot(snapshot, now) + if len(updates) == 0 && resetAt == nil { + return + } + shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now) + if !shouldPersistUpdates && resetAt == nil { return } - // Update account's Extra field asynchronously go func() { updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + if shouldPersistUpdates { + _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + } + if resetAt != nil { + _ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt) + } }() } +func (s *OpenAIGatewayService) UpdateCodexUsageSnapshotFromHeaders(ctx context.Context, accountID int64, headers http.Header) { + if accountID <= 0 || headers == nil { + return + } + if snapshot := ParseCodexRateLimitHeaders(headers); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, accountID, snapshot) + } +} + func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) { if reqBody == nil { return "", false @@ -3735,8 +4524,8 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p } // normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为: -// 1) store=false 2) stream=true -func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) { +// 1) store=false 2) 非 compact 保持 stream=true;compact 强制 stream=false +func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) { if len(body) == 0 { return body, false, nil } @@ -3744,22 +4533,40 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) { normalized := body changed := false - if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False { - next, err := sjson.SetBytes(normalized, "store", false) - if err != nil { - return body, false, fmt.Errorf("normalize passthrough body store=false: 
%w", err) + if compact { + if store := gjson.GetBytes(normalized, "store"); store.Exists() { + next, err := sjson.DeleteBytes(normalized, "store") + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body delete store: %w", err) + } + normalized = next + changed = true } - normalized = next - changed = true - } - - if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True { - next, err := sjson.SetBytes(normalized, "stream", true) - if err != nil { - return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err) + if stream := gjson.GetBytes(normalized, "stream"); stream.Exists() { + next, err := sjson.DeleteBytes(normalized, "stream") + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body delete stream: %w", err) + } + normalized = next + changed = true + } + } else { + if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False { + next, err := sjson.SetBytes(normalized, "store", false) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err) + } + normalized = next + changed = true + } + if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True { + next, err := sjson.SetBytes(normalized, "stream", true) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err) + } + normalized = next + changed = true } - normalized = next - changed = true } return normalized, changed, nil @@ -3804,6 +4611,40 @@ func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *s return &value } +func extractOpenAIServiceTier(reqBody map[string]any) *string { + if reqBody == nil { + return nil + } + raw, ok := reqBody["service_tier"].(string) + if !ok { + return nil + } + return normalizeOpenAIServiceTier(raw) +} + +func extractOpenAIServiceTierFromBody(body []byte) *string { + if len(body) == 0 { + return 
nil + } + return normalizeOpenAIServiceTier(gjson.GetBytes(body, "service_tier").String()) +} + +func normalizeOpenAIServiceTier(raw string) *string { + value := strings.ToLower(strings.TrimSpace(raw)) + if value == "" { + return nil + } + if value == "fast" { + value = "priority" + } + switch value { + case "priority", "flex": + return &value + default: + return nil + } +} + func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) { if c != nil { if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok { @@ -3859,3 +4700,11 @@ func normalizeOpenAIReasoningEffort(raw string) string { return "" } } + +func optionalTrimmedStringPtr(raw string) *string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return nil + } + return &trimmed +} diff --git a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go index d7c95ada..fe58e92f 100644 --- a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go +++ b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go @@ -211,6 +211,26 @@ func TestLogOpenAIInstructionsRequiredDebug_NonTargetErrorSkipped(t *testing.T) require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查")) } +func TestIsOpenAITransientProcessingError(t *testing.T) { + require.True(t, isOpenAITransientProcessingError( + http.StatusBadRequest, + "An error occurred while processing your request.", + nil, + )) + + require.True(t, isOpenAITransientProcessingError( + http.StatusBadRequest, + "", + []byte(`{"error":{"message":"An error occurred while processing your request. You can retry your request, or contact us through our help center at help.openai.com if the error persists. 
Please include the request ID req_123 in your message."}}`), + )) + + require.False(t, isOpenAITransientProcessingError( + http.StatusBadRequest, + "Missing required parameter: 'instructions'", + []byte(`{"error":{"message":"Missing required parameter: 'instructions'"}}`), + )) +} + func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) { gin.SetMode(gin.TestMode) logSink, restore := captureStructuredLog(t) @@ -264,3 +284,51 @@ func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing require.True(t, logSink.ContainsField("request_body_size")) require.False(t, logSink.ContainsField("request_body_preview")) } + +func TestOpenAIGatewayService_Forward_TransientProcessingErrorTriggersFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-processing-400"}, + }, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"An error occurred while processing your request. You can retry your request, or contact us through our help center at help.openai.com if the error persists. 
Please include the request ID req_123 in your message.","type":"invalid_request_error"}}`)), + }, + } + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ForceCodexCLI: false}, + }, + httpUpstream: upstream, + } + account := &Account{ + ID: 1001, + Name: "codex max套餐", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-test"}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + body := []byte(`{"model":"gpt-5.1-codex","stream":false,"input":[{"type":"text","text":"hello"}]}`) + + _, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode) + require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request") + require.False(t, c.Writer.Written(), "service 层应返回 failover 错误给上层换号,而不是直接向客户端写响应") +} diff --git a/backend/internal/service/openai_gateway_service_session_isolation_test.go b/backend/internal/service/openai_gateway_service_session_isolation_test.go new file mode 100644 index 00000000..d42fbcc5 --- /dev/null +++ b/backend/internal/service/openai_gateway_service_session_isolation_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsolateOpenAISessionID(t *testing.T) { + t.Run("empty_raw_returns_empty", func(t *testing.T) { + assert.Equal(t, "", isolateOpenAISessionID(1, "")) + assert.Equal(t, "", isolateOpenAISessionID(1, " ")) + }) + + t.Run("deterministic", func(t *testing.T) { + a := isolateOpenAISessionID(42, "sess_abc123") + b := isolateOpenAISessionID(42, "sess_abc123") + assert.Equal(t, a, b) + }) + + t.Run("different_apiKeyID_different_result", func(t *testing.T) { + a := isolateOpenAISessionID(1, 
"same_session") + b := isolateOpenAISessionID(2, "same_session") + require.NotEqual(t, a, b, "不同 API Key 使用相同 session_id 应产生不同隔离值") + }) + + t.Run("different_raw_different_result", func(t *testing.T) { + a := isolateOpenAISessionID(1, "session_a") + b := isolateOpenAISessionID(1, "session_b") + require.NotEqual(t, a, b) + }) + + t.Run("format_is_16_hex_chars", func(t *testing.T) { + result := isolateOpenAISessionID(99, "test_session") + assert.Len(t, result, 16, "应为 16 字符的 hex 字符串") + for _, ch := range result { + assert.True(t, (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f'), + "应仅包含 hex 字符: %c", ch) + } + }) + + t.Run("zero_apiKeyID_still_works", func(t *testing.T) { + result := isolateOpenAISessionID(0, "session") + assert.NotEmpty(t, result) + // apiKeyID=0 与 apiKeyID=1 应产生不同结果 + other := isolateOpenAISessionID(1, "session") + assert.NotEqual(t, result, other) + }) +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 89443b69..9e2f33f2 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -28,6 +29,22 @@ type stubOpenAIAccountRepo struct { accounts []Account } +type snapshotUpdateAccountRepo struct { + stubOpenAIAccountRepo + updateExtraCalls chan map[string]any +} + +func (r *snapshotUpdateAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + if r.updateExtraCalls != nil { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtraCalls <- copied + } + return nil +} + func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) { for i := range r.accounts { 
if r.accounts[i].ID == id { @@ -57,6 +74,10 @@ func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, pl return result, nil } +func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} + type stubConcurrencyCache struct { ConcurrencyCache loadBatchErr error @@ -895,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) { } } -func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { +func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Gateway: config.GatewayConfig{ @@ -919,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { } _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") - if err != nil { - t.Fatalf("expected nil error, got %v", err) + if err == nil || !strings.Contains(err.Error(), "stream usage incomplete") { + t.Fatalf("expected incomplete stream error, got %v", err) } if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") { t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) @@ -972,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) { } } +func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + 
Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + }() + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + _ = pr.Close() + if err == nil || !strings.Contains(err.Error(), "missing terminal event") { + t.Fatalf("expected missing terminal event error, got %v", err) + } +} + +func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + }() + + _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now()) + _ = pr.Close() + if err == nil || !strings.Contains(err.Error(), "missing terminal event") { + t.Fatalf("expected missing terminal event error, got %v", err) + } +} + +func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = 
pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now()) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 2, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + require.Equal(t, 1, result.usage.CacheReadInputTokens) +} + func TestOpenAIStreamingTooLong(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -1103,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { go func() { defer func() { _ = pw.Close() }() - _, _ = pw.Write([]byte("data: {}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{}}\n\n")) }() _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") @@ -1244,8 +1366,157 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) { } } +func TestOpenAIUpdateCodexUsageSnapshotFromHeaders(t *testing.T) { + repo := &snapshotUpdateAccountRepo{updateExtraCalls: make(chan map[string]any, 1)} + svc := &OpenAIGatewayService{accountRepo: repo} + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "12") + headers.Set("x-codex-secondary-used-percent", "34") + headers.Set("x-codex-primary-window-minutes", "300") + headers.Set("x-codex-secondary-window-minutes", "10080") + headers.Set("x-codex-primary-reset-after-seconds", "600") + headers.Set("x-codex-secondary-reset-after-seconds", "86400") + + svc.UpdateCodexUsageSnapshotFromHeaders(context.Background(), 123, headers) + + select { + case updates := <-repo.updateExtraCalls: + require.Equal(t, 12.0, updates["codex_5h_used_percent"]) + require.Equal(t, 34.0, updates["codex_7d_used_percent"]) + require.Equal(t, 600, 
updates["codex_5h_reset_after_seconds"]) + require.Equal(t, 86400, updates["codex_7d_reset_after_seconds"]) + case <-time.After(2 * time.Second): + t.Fatal("expected UpdateExtra to be called") + } +} + +func TestOpenAIResponsesRequestPathSuffix(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + tests := []struct { + name string + path string + want string + }{ + {name: "exact v1 responses", path: "/v1/responses", want: ""}, + {name: "compact v1 responses", path: "/v1/responses/compact", want: "/compact"}, + {name: "compact alias responses", path: "/responses/compact/", want: "/compact"}, + {name: "nested suffix", path: "/openai/v1/responses/compact/detail", want: "/compact/detail"}, + {name: "unrelated path", path: "/v1/chat/completions", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil) + require.Equal(t, tt.want, openAIResponsesRequestPathSuffix(c)) + }) + } +} + +func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader([]byte(`{"model":"gpt-5"}`))) + + svc := &OpenAIGatewayService{} + account := &Account{Type: AccountTypeOAuth} + + req, err := svc.buildUpstreamRequestOpenAIPassthrough(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token") + require.NoError(t, err) + require.Equal(t, chatgptCodexURL+"/compact", req.URL.String()) + require.Equal(t, "application/json", req.Header.Get("Accept")) + require.Equal(t, codexCLIVersion, req.Header.Get("Version")) + require.NotEmpty(t, req.Header.Get("Session_Id")) +} + +func TestOpenAIBuildUpstreamRequestCompactForcesJSONAcceptForOAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := 
gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader([]byte(`{"model":"gpt-5"}`))) + + svc := &OpenAIGatewayService{} + account := &Account{ + Type: AccountTypeOAuth, + Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"}, + } + + req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", true) + require.NoError(t, err) + require.Equal(t, chatgptCodexURL+"/compact", req.URL.String()) + require.Equal(t, "application/json", req.Header.Get("Accept")) + require.Equal(t, codexCLIVersion, req.Header.Get("Version")) + require.NotEmpty(t, req.Header.Get("Session_Id")) +} + +func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", bytes.NewReader([]byte(`{"model":"gpt-5"}`))) + + svc := &OpenAIGatewayService{cfg: &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + }} + account := &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{"base_url": "https://example.com/v1"}, + } + + req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", false) + require.NoError(t, err) + require.Equal(t, "https://example.com/v1/responses/compact", req.URL.String()) +} + +func TestOpenAIBuildUpstreamRequestOAuthOfficialClientOriginatorCompatibility(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + userAgent string + originator string + wantOriginator string + }{ + {name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"}, + {name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"}, 
+ {name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader([]byte(`{"model":"gpt-5"}`))) + if tt.userAgent != "" { + c.Request.Header.Set("User-Agent", tt.userAgent) + } + if tt.originator != "" { + c.Request.Header.Set("originator", tt.originator) + } + + svc := &OpenAIGatewayService{} + account := &Account{ + Type: AccountTypeOAuth, + Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"}, + } + + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) + req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", isCodexCLI) + require.NoError(t, err) + require.Equal(t, tt.wantOriginator, req.Header.Get("originator")) + }) + } +} + // ==================== P1-08 修复:model 替换性能优化测试 ==================== +// ==================== P1-08 修复:model 替换性能优化测试 ============= func TestReplaceModelInSSELine(t *testing.T) { svc := &OpenAIGatewayService{} @@ -1504,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) { require.Equal(t, 3, usage.InputTokens) require.Equal(t, 5, usage.OutputTokens) require.Equal(t, 2, usage.CacheReadInputTokens) + + // done 事件同样可能携带最终 usage + svc.parseSSEUsage(`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`, usage) + require.Equal(t, 13, usage.InputTokens) + require.Equal(t, 15, usage.OutputTokens) + require.Equal(t, 4, usage.CacheReadInputTokens) } func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { @@ -1572,3 +1849,27 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { require.Contains(t, rec.Header().Get("Content-Type"), 
"text/event-stream") require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`) } + +func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.failed","error":{"message":"upstream rejected request"}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + require.Nil(t, usage) + require.Error(t, err) + require.Equal(t, http.StatusBadGateway, rec.Code) + require.Contains(t, rec.Body.String(), "upstream rejected request") + require.Contains(t, rec.Header().Get("Content-Type"), "application/json") +} diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go new file mode 100644 index 00000000..9bf3fba3 --- /dev/null +++ b/backend/internal/service/openai_model_mapping.go @@ -0,0 +1,19 @@ +package service + +// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible +// forwarding. Group-level default mapping only applies when the account itself +// did not match any explicit model_mapping rule. 
+func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { + if account == nil { + if defaultMappedModel != "" { + return defaultMappedModel + } + return requestedModel + } + + mappedModel, matched := account.ResolveMappedModel(requestedModel) + if !matched && defaultMappedModel != "" { + return defaultMappedModel + } + return mappedModel +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go new file mode 100644 index 00000000..7af3ecae --- /dev/null +++ b/backend/internal/service/openai_model_mapping_test.go @@ -0,0 +1,70 @@ +package service + +import "testing" + +func TestResolveOpenAIForwardModel(t *testing.T) { + tests := []struct { + name string + account *Account + requestedModel string + defaultMappedModel string + expectedModel string + }{ + { + name: "falls back to group default when account has no mapping", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt-5.4", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-4o-mini", + }, + { + name: "preserves exact passthrough mapping instead of group default", + account: &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + }, + requestedModel: "gpt-5.4", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-5.4", + }, + { + name: "preserves wildcard passthrough mapping instead of group default", + account: &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-*": "gpt-5.4", + }, + }, + }, + requestedModel: "gpt-5.4", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-5.4", + }, + { + name: "uses account remap when explicit target differs", + account: &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5.4", + }, + }, + }, + requestedModel: "gpt-5", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-5.4", + }, + } + 
+ for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := resolveOpenAIForwardModel(tt.account, tt.requestedModel, tt.defaultMappedModel); got != tt.expectedModel { + t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", got, tt.expectedModel) + } + }) + } +} diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 0840d3b1..f51a7491 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -236,6 +236,60 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali require.NotContains(t, body, "\"name\":\"edit\"") } +func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Content-Type", "application/json") + + originalBody := []byte(`{"model":"gpt-5.1-codex","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"compact me"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"cmp_123","usage":{"input_tokens":11,"output_tokens":22}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": 
"chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.Stream) + + require.False(t, gjson.GetBytes(upstream.lastBody, "store").Exists()) + require.False(t, gjson.GetBytes(upstream.lastBody, "stream").Exists()) + require.Equal(t, "gpt-5.1-codex", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "compact me", gjson.GetBytes(upstream.lastBody, "input.0.text").String()) + require.Equal(t, "local-test-instructions", strings.TrimSpace(gjson.GetBytes(upstream.lastBody, "instructions").String())) + require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept")) + require.Equal(t, codexCLIVersion, upstream.lastReq.Header.Get("Version")) + require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id")) + require.Equal(t, "chatgpt.com", upstream.lastReq.Host) + require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id")) + require.Contains(t, rec.Body.String(), `"id":"cmp_123"`) +} + func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { gin.SetMode(gin.TestMode) logSink, restore := captureStructuredLog(t) @@ -385,7 +439,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) headers := make(http.Header) headers.Set("Content-Type", "application/json") @@ -399,7 +453,14 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes 
resp := &http.Response{ StatusCode: http.StatusOK, Header: headers, - Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"h"}`, + "", + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), } upstream := &httpUpstreamRecorder{resp: resp} @@ -617,7 +678,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") - originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"service_tier":"fast","input":[{"type":"text","text":"hi"}]}`) upstreamSSE := strings.Join([]string{ `data: {"type":"response.output_text.delta","delta":"h"}`, @@ -657,6 +718,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test require.GreaterOrEqual(t, time.Since(start), time.Duration(0)) require.NotNil(t, result.FirstTokenMs) require.GreaterOrEqual(t, *result.FirstTokenMs, 0) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "priority", *result.ServiceTier) } func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage(t *testing.T) { @@ -723,7 +786,7 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd c.Request.Header.Set("User-Agent", "curl/8.0") c.Request.Header.Set("X-Test", "keep") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`) + originalBody := 
[]byte(`{"model":"gpt-5.2","stream":false,"service_tier":"flex","max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`) resp := &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, @@ -749,8 +812,11 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd RateMultiplier: f64p(1), } - _, err := svc.Forward(context.Background(), c, account, originalBody) + result, err := svc.Forward(context.Background(), c, account, originalBody) require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "flex", *result.ServiceTier) require.NotNil(t, upstream.lastReq) require.Equal(t, originalBody, upstream.lastBody) require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String()) @@ -836,7 +902,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_InfoWhenStreamEndsWithoutDone(t * } _, err := svc.Forward(context.Background(), c, account, originalBody) - require.NoError(t, err) + require.EqualError(t, err, "stream usage incomplete: missing terminal event") require.True(t, logSink.ContainsMessage("上游流在未收到 [DONE] 时结束,疑似断流")) require.True(t, logSink.ContainsMessageAtLevel("上游流在未收到 [DONE] 时结束,疑似断流", "info")) require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-truncate")) @@ -852,11 +918,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *t c.Request.Header.Set("x-stainless-timeout", "120000") c.Request.Header.Set("X-Test", "keep") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) resp := &http.Response{ StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-default"}}, - Body: 
io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-filter-default"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), } upstream := &httpUpstreamRecorder{resp: resp} svc := &OpenAIGatewayService{ @@ -893,11 +964,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured c.Request.Header.Set("x-stainless-timeout", "120000") c.Request.Header.Set("X-Test", "keep") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) resp := &http.Response{ StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-allow"}}, - Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-filter-allow"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), } upstream := &httpUpstreamRecorder{resp: resp} svc := &OpenAIGatewayService{ diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index 07cb5472..bd82e107 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -7,7 
+7,6 @@ import ( "io" "log/slog" "net/http" - "net/url" "regexp" "sort" "strconv" @@ -15,6 +14,7 @@ import ( "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) @@ -130,6 +130,7 @@ type OpenAITokenInfo struct { ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` OrganizationID string `json:"organization_id,omitempty"` + PlanType string `json:"plan_type,omitempty"` } // ExchangeCode exchanges authorization code for tokens @@ -202,6 +203,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID tokenInfo.OrganizationID = userInfo.OrganizationID + tokenInfo.PlanType = userInfo.PlanType } return tokenInfo, nil @@ -246,6 +248,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID tokenInfo.OrganizationID = userInfo.OrganizationID + tokenInfo.PlanType = userInfo.PlanType } return tokenInfo, nil @@ -273,7 +276,13 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi req.Header.Set("Referer", "https://sora.chatgpt.com/") req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - client := newOpenAIOAuthHTTPClient(proxyURL) + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: 120 * time.Second, + }) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err) + } resp, err := client.Do(req) if err != nil { return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) @@ -504,6 +513,9 @@ func (s 
*OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) if tokenInfo.OrganizationID != "" { creds["organization_id"] = tokenInfo.OrganizationID } + if tokenInfo.PlanType != "" { + creds["plan_type"] = tokenInfo.PlanType + } if strings.TrimSpace(tokenInfo.ClientID) != "" { creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID) } @@ -530,19 +542,6 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64 return proxy.URL(), nil } -func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client { - transport := &http.Transport{} - if strings.TrimSpace(proxyURL) != "" { - if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" { - transport.Proxy = http.ProxyURL(parsed) - } - } - return &http.Client{ - Timeout: 120 * time.Second, - Transport: transport, - } -} - func normalizeOpenAIOAuthPlatform(platform string) string { switch strings.ToLower(strings.TrimSpace(platform)) { case PlatformSora: diff --git a/backend/internal/service/openai_privacy_service.go b/backend/internal/service/openai_privacy_service.go new file mode 100644 index 00000000..90cd522d --- /dev/null +++ b/backend/internal/service/openai_privacy_service.go @@ -0,0 +1,77 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/imroc/req/v3" +) + +// PrivacyClientFactory creates an HTTP client for privacy API calls. +// Injected from repository layer to avoid import cycles. +type PrivacyClientFactory func(proxyURL string) (*req.Client, error) + +const ( + openAISettingsURL = "https://chatgpt.com/backend-api/settings/account_user_setting" + + PrivacyModeTrainingOff = "training_off" + PrivacyModeFailed = "training_set_failed" + PrivacyModeCFBlocked = "training_set_cf_blocked" +) + +// disableOpenAITraining calls ChatGPT settings API to turn off "Improve the model for everyone". +// Returns privacy_mode value: "training_off" on success, "cf_blocked" / "failed" on failure. 
+func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFactory, accessToken, proxyURL string) string { + if accessToken == "" || clientFactory == nil { + return "" + } + + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + client, err := clientFactory(proxyURL) + if err != nil { + slog.Warn("openai_privacy_client_error", "error", err.Error()) + return PrivacyModeFailed + } + + resp, err := client.R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+accessToken). + SetHeader("Origin", "https://chatgpt.com"). + SetHeader("Referer", "https://chatgpt.com/"). + SetQueryParam("feature", "training_allowed"). + SetQueryParam("value", "false"). + Patch(openAISettingsURL) + + if err != nil { + slog.Warn("openai_privacy_request_error", "error", err.Error()) + return PrivacyModeFailed + } + + if resp.StatusCode == 403 || resp.StatusCode == 503 { + body := resp.String() + if strings.Contains(body, "cloudflare") || strings.Contains(body, "cf-") || strings.Contains(body, "Just a moment") { + slog.Warn("openai_privacy_cf_blocked", "status", resp.StatusCode) + return PrivacyModeCFBlocked + } + } + + if !resp.IsSuccessState() { + slog.Warn("openai_privacy_failed", "status", resp.StatusCode, "body", truncate(resp.String(), 200)) + return PrivacyModeFailed + } + + slog.Info("openai_privacy_training_disabled") + return PrivacyModeTrainingOff +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + fmt.Sprintf("...(%d more)", len(s)-n) +} diff --git a/backend/internal/service/openai_sticky_compat.go b/backend/internal/service/openai_sticky_compat.go index e897debc..fe0f1309 100644 --- a/backend/internal/service/openai_sticky_compat.go +++ b/backend/internal/service/openai_sticky_compat.go @@ -29,6 +29,13 @@ func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, openAIStickyLegacyDualWriteTotal.Load() } +// DeriveSessionHashFromSeed computes the current-format 
sticky-session hash +// from an arbitrary seed string. +func DeriveSessionHashFromSeed(seed string) string { + currentHash, _ := deriveOpenAISessionHashes(seed) + return currentHash +} + func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) { normalized := strings.TrimSpace(sessionID) if normalized == "" { diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index a8a6b96c..69477ce7 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -20,7 +20,7 @@ const ( openAILockWarnThresholdMs = 250 ) -// OpenAITokenRuntimeMetrics 表示 OpenAI token 刷新与锁竞争保护指标快照。 +// OpenAITokenRuntimeMetrics is a snapshot of refresh and lock contention metrics. type OpenAITokenRuntimeMetrics struct { RefreshRequests int64 RefreshSuccess int64 @@ -72,15 +72,18 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() { m.lastObservedUnixMs.Store(time.Now().UnixMilli()) } -// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) +// OpenAITokenCache token cache interface. type OpenAITokenCache = GeminiTokenCache -// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token +// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts. type OpenAITokenProvider struct { accountRepo AccountRepository tokenCache OpenAITokenCache openAIOAuthService *OpenAIOAuthService metrics *openAITokenRuntimeMetricsStore + refreshAPI *OAuthRefreshAPI + executor OAuthRefreshExecutor + refreshPolicy ProviderRefreshPolicy } func NewOpenAITokenProvider( @@ -93,9 +96,21 @@ func NewOpenAITokenProvider( tokenCache: tokenCache, openAIOAuthService: openAIOAuthService, metrics: &openAITokenRuntimeMetricsStore{}, + refreshPolicy: OpenAIProviderRefreshPolicy(), } } +// SetRefreshAPI injects unified OAuth refresh API and executor. 
+func (p *OpenAITokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) { + p.refreshAPI = api + p.executor = executor +} + +// SetRefreshPolicy injects caller-side refresh policy. +func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) { + p.refreshPolicy = policy +} + func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics { if p == nil { return OpenAITokenRuntimeMetrics{} @@ -110,7 +125,7 @@ func (p *OpenAITokenProvider) ensureMetrics() { } } -// GetAccessToken 获取有效的 access_token +// GetAccessToken returns a valid access_token. func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { p.ensureMetrics() if account == nil { @@ -122,7 +137,7 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou cacheKey := OpenAITokenCacheKey(account) - // 1. 先尝试缓存 + // 1) Try cache first. if p.tokenCache != nil { if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { slog.Debug("openai_token_cache_hit", "account_id", account.ID) @@ -134,114 +149,62 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou slog.Debug("openai_token_cache_miss", "account_id", account.ID) - // 2. 如果即将过期则刷新 + // 2) Refresh if needed (pre-expiry skew). expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew refreshFailed := false - if needsRefresh && p.tokenCache != nil { + + if needsRefresh && p.refreshAPI != nil && p.executor != nil { + p.metrics.refreshRequests.Add(1) + p.metrics.touchNow() + + // Sora accounts skip OpenAI OAuth refresh and keep existing token path. 
+ if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) + refreshFailed = true + } else { + result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew) + if err != nil { + if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { + return "", err + } + slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) + refreshFailed = true + } else if result.LockHeld { + if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache { + p.metrics.lockContention.Add(1) + p.metrics.touchNow() + token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) + if waitErr != nil { + return "", waitErr + } + if strings.TrimSpace(token) != "" { + slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) + return token, nil + } + } + } else if result.Refreshed { + p.metrics.refreshSuccess.Add(1) + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") + } else { + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") + } + } + } else if needsRefresh && p.tokenCache != nil { + // Backward-compatible test path when refreshAPI is not injected. 
p.metrics.refreshRequests.Add(1) p.metrics.touchNow() locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) if lockErr == nil && locked { defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() - - // 拿到锁后再次检查缓存(另一个 worker 可能已刷新) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil - } - - // 从数据库获取最新账户信息 - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh - } - expiresAt = account.GetCredentialAsTime("expires_at") - if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if account.Platform == PlatformSora { - slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) - // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 - refreshFailed = true - } else if p.openAIOAuthService == nil { - slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) - p.metrics.refreshFailure.Add(1) - refreshFailed = true // 无法刷新,标记失败 - } else { - tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) - if err != nil { - // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token - slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) - p.metrics.refreshFailure.Add(1) - refreshFailed = true // 刷新失败,标记以使用短 TTL - } else { - p.metrics.refreshSuccess.Add(1) - newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } - account.Credentials = newCredentials - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr) - } - expiresAt = account.GetCredentialAsTime("expires_at") - } - } - } } else if lockErr != nil { - // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时) p.metrics.lockAcquireFailure.Add(1) 
p.metrics.touchNow() - slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) - - // 检查 ctx 是否已取消 - if ctx.Err() != nil { - return "", ctx.Err() - } - - // 从数据库获取最新账户信息 - if p.accountRepo != nil { - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh - } - } - expiresAt = account.GetCredentialAsTime("expires_at") - - // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 - if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if account.Platform == PlatformSora { - slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID) - // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 - refreshFailed = true - } else if p.openAIOAuthService == nil { - slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) - p.metrics.refreshFailure.Add(1) - refreshFailed = true - } else { - tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) - if err != nil { - slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err) - p.metrics.refreshFailure.Add(1) - refreshFailed = true - } else { - p.metrics.refreshSuccess.Add(1) - newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } - account.Credentials = newCredentials - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr) - } - expiresAt = account.GetCredentialAsTime("expires_at") - } - } - } + slog.Warn("openai_token_lock_failed", "account_id", account.ID, "error", lockErr) } else { - // 锁被其他 worker 持有:使用短轮询+jitter,降低固定等待导致的尾延迟台阶。 p.metrics.lockContention.Add(1) p.metrics.touchNow() token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) @@ -260,22 +223,23 @@ func (p *OpenAITokenProvider) 
GetAccessToken(ctx context.Context, account *Accou return "", errors.New("access_token not found in credentials") } - // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) + // 3) Populate cache with TTL. if p.tokenCache != nil { latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) if isStale && latestAccount != nil { - // 版本过时,使用 DB 中的最新 token slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID) accessToken = latestAccount.GetOpenAIAccessToken() if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found after version check") } - // 不写入缓存,让下次请求重新处理 } else { ttl := 30 * time.Minute if refreshFailed { - // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动 - ttl = time.Minute + if p.refreshPolicy.FailureTTL > 0 { + ttl = p.refreshPolicy.FailureTTL + } else { + ttl = time.Minute + } slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed") } else if expiresAt != nil { until := time.Until(*expiresAt) diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go index 3fe08179..9a8803d3 100644 --- a/backend/internal/service/openai_ws_account_sticky_test.go +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -48,6 +48,43 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T } } +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + account := Account{ + ID: 12, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + RateLimitResetAt: &rateLimitedUntil, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := 
&OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil) + require.NoError(t, err) + require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连") + boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl") + require.NoError(t, getErr) + require.Zero(t, boundAccountID) +} + func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { ctx := context.Background() groupID := int64(23) diff --git a/backend/internal/service/openai_ws_client.go b/backend/internal/service/openai_ws_client.go index 9f3c47b7..80b75530 100644 --- a/backend/internal/service/openai_ws_client.go +++ b/backend/internal/service/openai_ws_client.go @@ -11,6 +11,7 @@ import ( "sync/atomic" "time" + openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" coderws "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) @@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct { conn *coderws.Conn } +var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil) + func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error { if c == nil || c.conn == nil { return errOpenAIWSConnClosed @@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro } } +func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.conn == nil { + return coderws.MessageText, nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + msgType, payload, err := c.conn.Read(ctx) + if err != nil { + return coderws.MessageText, nil, err + } + 
return msgType, payload, nil +} + +func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Write(ctx, msgType, payload) +} + func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error { if c == nil || c.conn == nil { return errOpenAIWSConnClosed diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 74ba472f..1d3d8fdf 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -46,9 +46,10 @@ const ( openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024 openAIWSPayloadSizeEstimateMaxItems = 16 - openAIWSEventFlushBatchSizeDefault = 4 - openAIWSEventFlushIntervalDefault = 25 * time.Millisecond - openAIWSPayloadLogSampleDefault = 0.2 + openAIWSEventFlushBatchSizeDefault = 4 + openAIWSEventFlushIntervalDefault = 25 * time.Millisecond + openAIWSPayloadLogSampleDefault = 0.2 + openAIWSPassthroughIdleTimeoutDefault = time.Hour openAIWSStoreDisabledConnModeStrict = "strict" openAIWSStoreDisabledConnModeAdaptive = "adaptive" @@ -863,7 +864,8 @@ func isOpenAIWSClientDisconnectError(err error) bool { strings.Contains(message, "unexpected eof") || strings.Contains(message, "use of closed network connection") || strings.Contains(message, "connection reset by peer") || - strings.Contains(message, "broken pipe") + strings.Contains(message, "broken pipe") || + strings.Contains(message, "an established connection was aborted") } func classifyOpenAIWSReadFallbackReason(err error) string { @@ -904,6 +906,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool { return s.openaiWSPool } +func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer { + if s == nil { + return nil + } + s.openaiWSPassthroughDialerOnce.Do(func() { + 
if s.openaiWSPassthroughDialer == nil { + s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer() + } + }) + return s.openaiWSPassthroughDialer +} + func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot { pool := s.getOpenAIWSConnPool() if pool == nil { @@ -967,6 +981,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { return 15 * time.Minute } +func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration { + if timeout := s.openAIWSReadTimeout(); timeout > 0 { + return timeout + } + return openAIWSPassthroughIdleTimeoutDefault +} + func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration { if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second @@ -1103,11 +1124,22 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders( headers.Set("accept-language", v) } } - if sessionResolution.SessionID != "" { - headers.Set("session_id", sessionResolution.SessionID) - } - if sessionResolution.ConversationID != "" { - headers.Set("conversation_id", sessionResolution.ConversationID) + // OAuth 账号:将 apiKeyID 混入 session 标识符,防止跨用户会话碰撞。 + if account != nil && account.Type == AccountTypeOAuth { + apiKeyID := getAPIKeyIDFromContext(c) + if sessionResolution.SessionID != "" { + headers.Set("session_id", isolateOpenAISessionID(apiKeyID, sessionResolution.SessionID)) + } + if sessionResolution.ConversationID != "" { + headers.Set("conversation_id", isolateOpenAISessionID(apiKeyID, sessionResolution.ConversationID)) + } + } else { + if sessionResolution.SessionID != "" { + headers.Set("session_id", sessionResolution.SessionID) + } + if sessionResolution.ConversationID != "" { + headers.Set("conversation_id", sessionResolution.ConversationID) + } } if state := strings.TrimSpace(turnState); state != "" { headers.Set(openAIWSTurnStateHeader, state) @@ -1120,11 +1152,7 @@ func (s 
*OpenAIGatewayService) buildOpenAIWSHeaders( if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { headers.Set("chatgpt-account-id", chatgptAccountID) } - if isCodexCLI { - headers.Set("originator", "codex_cli_rs") - } else { - headers.Set("originator", "opencode") - } + headers.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) } betaValue := openAIWSBetaV2Value @@ -1836,9 +1864,22 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( wsPath, account.ProxyID != nil && account.Proxy != nil, ) + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error())) + } return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err) } - defer lease.Release() + // cleanExit 标记正常终端事件退出,此时上游不会再发送帧,连接可安全归还复用。 + // 所有异常路径(读写错误、error 事件等)已在各自分支中提前调用 MarkBroken, + // 因此 defer 中只需处理正常退出时不 MarkBroken 即可。 + cleanExit := false + defer func() { + if !cleanExit { + lease.MarkBroken() + } + lease.Release() + }() connID := strings.TrimSpace(lease.ConnID()) logOpenAIWSModeDebug( "connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v", @@ -2119,6 +2160,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { errMsg = "Upstream websocket error" @@ -2215,6 +2257,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( } if isTerminalEvent { + cleanExit = true break } } @@ -2285,9 +2328,11 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( RequestID: responseID, Usage: 
*usage, Model: originalModel, + ServiceTier: extractOpenAIServiceTier(reqBody), ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), Stream: reqStream, OpenAIWSMode: true, + ResponseHeaders: lease.HandshakeHeaders(), Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil @@ -2322,7 +2367,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled - ingressMode := OpenAIWSIngressModeShared + ingressMode := OpenAIWSIngressModeCtxPool if modeRouterV2Enabled { ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault) if ingressMode == OpenAIWSIngressModeOff { @@ -2332,6 +2377,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( nil, ) } + switch ingressMode { + case OpenAIWSIngressModePassthrough: + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) + } + return s.proxyResponsesWebSocketV2Passthrough( + ctx, + c, + clientConn, + account, + token, + firstClientMessage, + hooks, + wsDecision, + ) + case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + // continue + default: + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode only supports ctx_pool/passthrough", + nil, + ) + } } if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) @@ -2497,7 +2566,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } } - isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + isCodexCLI := 
openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) baseAcquireReq := openAIWSAcquireRequest{ Account: account, @@ -2597,6 +2666,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( wsPath, account.ProxyID != nil && account.Proxy != nil, ) + var dialErr *openAIWSDialError + if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error())) + } if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { return nil, NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, @@ -2735,6 +2808,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw) fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound && @@ -2871,9 +2945,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( RequestID: responseID, Usage: usage, Model: originalModel, + ServiceTier: extractOpenAIServiceTierFromBody(payload), ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), Stream: reqStream, OpenAIWSMode: true, + ResponseHeaders: lease.HandshakeHeaders(), Duration: time.Since(turnStart), 
FirstTokenMs: firstTokenMs, }, nil @@ -2917,12 +2993,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( pinnedSessionConnID = connID } } + // lastTurnClean 标记最后一轮 sendAndRelay 是否正常完成(收到终端事件且客户端未断连)。 + // 所有异常路径(读写错误、error 事件、客户端断连)已在各自分支或上层(L3403)中 MarkBroken, + // 因此 releaseSessionLease 中只需在非正常结束时 MarkBroken。 + lastTurnClean := false releaseSessionLease := func() { if sessionLease == nil { return } - if dedicatedMode { - // dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。 + if !lastTurnClean { sessionLease.MarkBroken() } unpinSessionConn(sessionConnID) @@ -3317,6 +3396,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) if relayErr != nil { + lastTurnClean = false if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { continue } @@ -3336,6 +3416,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( turnRetry = 0 turnPrevRecoveryTried = false lastTurnFinishedAt = time.Now() + lastTurnClean = true if hooks != nil && hooks.AfterTurn != nil { hooks.AfterTurn(turn, result, nil) } @@ -3561,6 +3642,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { errMsg = "OpenAI websocket prewarm error" @@ -3755,7 +3837,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { return nil, nil } - if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() { + if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() || !account.IsSchedulable() { _ = 
store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) return nil, nil } @@ -3824,6 +3906,36 @@ func classifyOpenAIWSAcquireError(err error) string { return "acquire_conn" } +func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + if strings.Contains(errType, "rate_limit") || strings.Contains(errType, "usage_limit") { + return true + } + if strings.Contains(code, "rate_limit") || strings.Contains(code, "usage_limit") || strings.Contains(code, "insufficient_quota") { + return true + } + if strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached") { + return true + } + if strings.Contains(msg, "rate limit") && (strings.Contains(msg, "reached") || strings.Contains(msg, "exceeded")) { + return true + } + return false +} + +func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) { + if s == nil || s.rateLimitService == nil || account == nil || account.Platform != PlatformOpenAI { + return + } + if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return + } + s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody) +} + func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { code := strings.ToLower(strings.TrimSpace(codeRaw)) errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) @@ -3836,9 +3948,14 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri return "ws_unsupported", true case "websocket_connection_limit_reached": return "ws_connection_limit_reached", true + case "invalid_encrypted_content": + return "invalid_encrypted_content", true case "previous_response_not_found": return "previous_response_not_found", 
true } + if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return "upstream_rate_limited", false + } if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { return "upgrade_required", true } @@ -3851,6 +3968,10 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") { return "ws_connection_limit_reached", true } + if strings.Contains(msg, "invalid_encrypted_content") || + (strings.Contains(msg, "encrypted content") && strings.Contains(msg, "could not be verified")) { + return "invalid_encrypted_content", true + } if strings.Contains(msg, "previous_response_not_found") || (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) { return "previous_response_not_found", true @@ -3875,6 +3996,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { case strings.Contains(errType, "invalid_request"), strings.Contains(code, "invalid_request"), strings.Contains(code, "bad_request"), + code == "invalid_encrypted_content", code == "previous_response_not_found": return http.StatusBadRequest case strings.Contains(errType, "authentication"), @@ -3884,9 +4006,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { case strings.Contains(errType, "permission"), strings.Contains(code, "forbidden"): return http.StatusForbidden - case strings.Contains(errType, "rate_limit"), - strings.Contains(code, "rate_limit"), - strings.Contains(code, "insufficient_quota"): + case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""): return http.StatusTooManyRequests default: return http.StatusBadGateway diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index 5a3c12c3..c527f2eb 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ 
b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式") require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式") - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + _ = clientConn.Close(coderws.StatusNormalClosure, "done") select { case serverErr := <-serverErrCh: @@ -298,6 +298,142 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接") } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + upstreamConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: upstreamConn} + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPassthroughDialer: captureDialer, + } + + account := &Account{ + ID: 452, + Name: "openai-ingress-passthrough", + Platform: PlatformOpenAI, + Type: 
AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }, + } + + serverErrCh := make(chan error, 1) + resultCh := make(chan *OpenAIForwardResult, 1) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + if turnErr == nil && result != nil { + resultCh <- result + } + }, + } + + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, 
[]byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String()) + require.Equal(t, "resp_passthrough_turn_1", gjson.GetBytes(event, "response.id").String()) + _ = clientConn.Close(coderws.StatusNormalClosure, "done") + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 passthrough websocket 结束超时") + } + + select { + case result := <-resultCh: + require.Equal(t, "resp_passthrough_turn_1", result.RequestID) + require.True(t, result.OpenAIWSMode) + require.Equal(t, 2, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "priority", *result.ServiceTier) + case <-time.After(2 * time.Second): + t.Fatal("未收到 passthrough turn 结果回调") + } + + require.Equal(t, 1, captureDialer.DialCount(), "passthrough 模式应直接建立上游 websocket") + require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create") +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) { gin.SetMode(gin.TestMode) @@ -2459,7 +2595,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnect require.NoError(t, err) writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) - err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`)) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false,"service_tier":"flex"}`)) cancelWrite() require.NoError(t, err) // 
立即关闭客户端,模拟客户端在 relay 期间断连。 @@ -2477,6 +2613,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnect require.Equal(t, "resp_ingress_disconnect", result.RequestID) require.Equal(t, 2, result.Usage.InputTokens) require.Equal(t, 1, result.Usage.OutputTokens) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "flex", *result.ServiceTier) case <-time.After(2 * time.Second): t.Fatal("未收到断连后的 turn 结果回调") } diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go index 592801f6..7a76c385 100644 --- a/backend/internal/service/openai_ws_forwarder_success_test.go +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/stretchr/testify/require" @@ -379,7 +380,8 @@ func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) { require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_")) } - require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链") + // 条件式 MarkBroken:正常终端事件退出后连接归还复用,不再无条件销毁。 + require.Equal(t, int64(1), upgradeCount.Load(), "正常完成后连接应归还复用,不应每次新建") metrics := svc.SnapshotOpenAIWSPoolMetrics() require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1)) require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1)) @@ -453,8 +455,90 @@ func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段") require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true") require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta")) - require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id")) - require.Equal(t, "conv-oauth-1", 
captureDialer.lastHeaders.Get("conversation_id")) + // OAuth 账号的 session_id/conversation_id 应被 isolateOpenAISessionID 隔离, + // 测试中未设置 api_key 到 context,apiKeyID=0。 + require.Equal(t, isolateOpenAISessionID(0, "sess-oauth-1"), captureDialer.lastHeaders.Get("session_id")) + require.Equal(t, isolateOpenAISessionID(0, "conv-oauth-1"), captureDialer.lastHeaders.Get("conversation_id")) +} + +func TestOpenAIGatewayService_Forward_WSv2_OAuthOriginatorCompatibility(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + userAgent string + originator string + wantOriginator string + }{ + {name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"}, + {name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"}, + {name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + if tt.userAgent != "" { + c.Request.Header.Set("User-Agent", tt.userAgent) + } + if tt.originator != "" { + c.Request.Header.Set("originator", tt.originator) + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.AllowStoreRecovery = false + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + 
[]byte(`{"type":"response.completed","response":{"id":"resp_oauth_originator","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 129, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token-1", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, tt.wantOriginator, captureDialer.lastHeaders.Get("originator")) + }) + } } func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) { @@ -515,7 +599,8 @@ func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheK require.NotNil(t, result) require.Equal(t, "resp_prompt_cache_key", result.RequestID) - require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id")) + // OAuth 账号的 session_id 应被 isolateOpenAISessionID 隔离(apiKeyID=0,未在 context 设置)。 + require.Equal(t, isolateOpenAISessionID(0, "pcache_123"), captureDialer.lastHeaders.Get("session_id")) require.Empty(t, captureDialer.lastHeaders.Get("conversation_id")) require.NotNil(t, captureConn.lastWrite) require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists()) @@ -880,6 +965,10 @@ func 
TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t require.NotNil(t, result1) require.Equal(t, "resp_meta_1", result1.RequestID) + require.Len(t, captureConn.writes, 1) + firstWrite := requestToJSONString(captureConn.writes[0]) + require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String()) + rec2 := httptest.NewRecorder() c2, _ := gin.CreateTestContext(rec2) c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) @@ -893,7 +982,7 @@ func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接") require.Len(t, captureConn.writes, 2) - firstWrite := requestToJSONString(captureConn.writes[0]) + firstWrite = requestToJSONString(captureConn.writes[0]) secondWrite := requestToJSONString(captureConn.writes[1]) require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String()) require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String()) @@ -1282,6 +1371,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) { return event, nil } +func (c *openAIWSCaptureConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + payload, err := c.ReadMessage(ctx) + if err != nil { + return coderws.MessageText, nil, err + } + return coderws.MessageText, payload, nil +} + +func (c *openAIWSCaptureConn) WriteFrame(ctx context.Context, _ coderws.MessageType, payload []byte) error { + return c.WriteJSON(ctx, json.RawMessage(payload)) +} + func (c *openAIWSCaptureConn) Ping(ctx context.Context) error { _ = ctx return nil diff --git a/backend/internal/service/openai_ws_pool.go b/backend/internal/service/openai_ws_pool.go index db6a96a7..5950e028 100644 --- a/backend/internal/service/openai_ws_pool.go +++ b/backend/internal/service/openai_ws_pool.go @@ -126,6 +126,13 
@@ func (l *openAIWSConnLease) HandshakeHeader(name string) string { return l.conn.handshakeHeader(name) } +func (l *openAIWSConnLease) HandshakeHeaders() http.Header { + if l == nil || l.conn == nil { + return nil + } + return cloneHeader(l.conn.handshakeHeaders) +} + func (l *openAIWSConnLease) IsPrewarmed() bool { if l == nil || l.conn == nil { return false diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index df4d4871..76c66f2f 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -1,6 +1,7 @@ package service import ( + "bytes" "context" "encoding/json" "io" @@ -19,6 +20,47 @@ import ( "github.com/tidwall/gjson" ) +type httpUpstreamSequenceRecorder struct { + mu sync.Mutex + bodies [][]byte + reqs []*http.Request + + responses []*http.Response + errs []error + callCount int +} + +func (u *httpUpstreamSequenceRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.mu.Lock() + defer u.mu.Unlock() + + idx := u.callCount + u.callCount++ + u.reqs = append(u.reqs, req) + if req != nil && req.Body != nil { + b, _ := io.ReadAll(req.Body) + u.bodies = append(u.bodies, b) + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewReader(b)) + } else { + u.bodies = append(u.bodies, nil) + } + if idx < len(u.errs) && u.errs[idx] != nil { + return nil, u.errs[idx] + } + if idx < len(u.responses) { + return u.responses[idx], nil + } + if len(u.responses) == 0 { + return nil, nil + } + return u.responses[len(u.responses)-1], nil +} + +func (u *httpUpstreamSequenceRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + func 
TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) { gin.SetMode(gin.TestMode) wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -143,6 +185,176 @@ func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testi require.Equal(t, "client_protocol_http", reason) } +func TestOpenAIGatewayService_Forward_HTTPIngressRetriesInvalidEncryptedContentOnce(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) + + upstream := &httpUpstreamSequenceRecorder{ + responses: []*http.Response{ + { + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"error":{"code":"invalid_encrypted_content","type":"invalid_request_error","message":"The encrypted content could not be verified."}}`, + )), + }, + { + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"id":"resp_http_retry_ok","usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: 
NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 102, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me"}]},{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发") + require.Equal(t, 2, upstream.callCount, "命中 invalid_encrypted_content 后应只在 HTTP 路径重试一次") + require.Len(t, upstream.bodies, 2) + + firstBody := upstream.bodies[0] + secondBody := upstream.bodies[1] + require.False(t, gjson.GetBytes(firstBody, "previous_response_id").Exists(), "HTTP 首次请求仍应沿用原逻辑移除 previous_response_id") + require.True(t, gjson.GetBytes(firstBody, "input.0.encrypted_content").Exists(), "首次请求不应做发送前预清理") + require.Equal(t, "keep me", gjson.GetBytes(firstBody, "input.0.summary.0.text").String()) + + require.False(t, gjson.GetBytes(secondBody, "previous_response_id").Exists(), "HTTP 精确重试不应重新带回 previous_response_id") + require.False(t, gjson.GetBytes(secondBody, "input.0.encrypted_content").Exists(), "精确重试应移除 reasoning.encrypted_content") + require.Equal(t, "keep me", gjson.GetBytes(secondBody, "input.0.summary.0.text").String(), "精确重试应保留有效 reasoning summary") + require.Equal(t, "input_text", gjson.GetBytes(secondBody, "input.1.type").String(), "非 reasoning input 应保持原样") + + decision, _ := c.Get("openai_ws_transport_decision") + reason, _ := c.Get("openai_ws_transport_reason") + require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision) + require.Equal(t, "client_protocol_http", reason) 
+} + +func TestOpenAIGatewayService_Forward_HTTPIngressRetriesWrappedInvalidEncryptedContentOnce(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) + + upstream := &httpUpstreamSequenceRecorder{ + responses: []*http.Response{ + { + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"error":{"code":null,"message":"{\"error\":{\"message\":\"The encrypted content could not be verified.\",\"type\":\"invalid_request_error\",\"param\":null,\"code\":\"invalid_encrypted_content\"}}(traceid: fb7ad1dbc7699c18f8a02f258f1af5ab)","param":null,"type":"invalid_request_error"}}`, + )), + }, + { + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"req_http_retry_wrapped_ok"}, + }, + Body: io.NopCloser(strings.NewReader( + `{"id":"resp_http_retry_wrapped_ok","usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 103, + Name: "openai-apikey-wrapped", + Platform: PlatformOpenAI, + Type: 
AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry_wrapped","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me too"}]},{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发") + require.Equal(t, 2, upstream.callCount, "wrapped invalid_encrypted_content 也应只在 HTTP 路径重试一次") + require.Len(t, upstream.bodies, 2) + + firstBody := upstream.bodies[0] + secondBody := upstream.bodies[1] + require.True(t, gjson.GetBytes(firstBody, "input.0.encrypted_content").Exists(), "首次请求不应做发送前预清理") + require.False(t, gjson.GetBytes(secondBody, "input.0.encrypted_content").Exists(), "wrapped exact retry 应移除 reasoning.encrypted_content") + require.Equal(t, "keep me too", gjson.GetBytes(secondBody, "input.0.summary.0.text").String(), "wrapped exact retry 应保留有效 reasoning summary") + + decision, _ := c.Get("openai_ws_transport_decision") + reason, _ := c.Get("openai_ws_transport_reason") + require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision) + require.Equal(t, "client_protocol_http", reason) +} + func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) { gin.SetMode(gin.TestMode) wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -391,6 +603,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, + nil, cfg, nil, nil, @@ -1216,3 +1430,460 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOn require.True(t, gjson.GetBytes(requests[0], 
"previous_response_id").Exists(), "首轮请求应包含 previous_response_id") require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") } + +func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentRecoversOnce(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt := wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + if attempt == 1 { + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "invalid_encrypted_content", + "type": "invalid_request_error", + "message": "The encrypted content could not be verified.", + }, + }) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_ws_invalid_encrypted_content_recover_ok", + "model": "gpt-5.3-codex", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + "input_tokens_details": map[string]any{ + "cached_tokens": 0, + }, + }, + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: 
http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 95, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":[{"type":"reasoning","encrypted_content":"gAAA"},{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_ws_invalid_encrypted_content_recover_ok", result.RequestID) + require.Nil(t, upstream.lastReq, "invalid_encrypted_content 不应回退 HTTP") + require.Equal(t, int32(2), wsAttempts.Load(), "invalid_encrypted_content 应触发一次清洗后重试") + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "resp_ws_invalid_encrypted_content_recover_ok", gjson.Get(rec.Body.String(), "id").String()) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) 
+ wsRequestMu.Unlock() + require.Len(t, requests, 2) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id") + require.True(t, gjson.GetBytes(requests[0], `input.0.encrypted_content`).Exists(), "首轮请求应保留 encrypted reasoning") + require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") + require.False(t, gjson.GetBytes(requests[1], `input.0.encrypted_content`).Exists(), "恢复重试应移除 encrypted reasoning item") + require.Equal(t, "input_text", gjson.GetBytes(requests[1], `input.0.type`).String()) +} + +func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentSkipsRecoveryWithoutReasoningItem(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "invalid_encrypted_content", + "type": "invalid_request_error", + "message": "The encrypted content could not be verified.", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: 
&http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 96, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "invalid_encrypted_content 不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "缺少 reasoning encrypted item 时应跳过自动恢复重试") + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, strings.ToLower(rec.Body.String()), "encrypted content") + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) 
+ wsRequestMu.Unlock() + require.Len(t, requests, 1) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists()) + require.False(t, gjson.GetBytes(requests[0], `input.0.encrypted_content`).Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentRecoversSingleObjectInputAndKeepsSummary(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt := wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + if attempt == 1 { + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "invalid_encrypted_content", + "type": "invalid_request_error", + "message": "The encrypted content could not be verified.", + }, + }) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_ws_invalid_encrypted_content_object_ok", + "model": "gpt-5.3-codex", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + "input_tokens_details": map[string]any{ + "cached_tokens": 0, + }, + }, + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: 
&http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 97, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me"}]}}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_ws_invalid_encrypted_content_object_ok", result.RequestID) + require.Nil(t, upstream.lastReq, "invalid_encrypted_content 单对象 input 不应回退 HTTP") + require.Equal(t, int32(2), wsAttempts.Load(), "单对象 reasoning input 也应触发一次清洗后重试") + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) 
+ wsRequestMu.Unlock() + require.Len(t, requests, 2) + require.True(t, gjson.GetBytes(requests[0], `input.encrypted_content`).Exists(), "首轮单对象应保留 encrypted_content") + require.True(t, gjson.GetBytes(requests[1], `input.summary.0.text`).Exists(), "恢复重试应保留 reasoning summary") + require.False(t, gjson.GetBytes(requests[1], `input.encrypted_content`).Exists(), "恢复重试只应移除 encrypted_content") + require.Equal(t, "reasoning", gjson.GetBytes(requests[1], `input.type`).String()) + require.False(t, gjson.GetBytes(requests[1], `previous_response_id`).Exists(), "恢复重试应移除 previous_response_id") +} + +func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentKeepsPreviousResponseIDForFunctionCallOutput(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt := wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + if attempt == 1 { + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "invalid_encrypted_content", + "type": "invalid_request_error", + "message": "The encrypted content could not be verified.", + }, + }) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_ws_invalid_encrypted_content_function_call_output_ok", + "model": "gpt-5.3-codex", + "usage": map[string]any{ + "input_tokens": 1, + 
"output_tokens": 1, + "input_tokens_details": map[string]any{ + "cached_tokens": 0, + }, + }, + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 98, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_function_call","input":[{"type":"reasoning","encrypted_content":"gAAA"},{"type":"function_call_output","call_id":"call_123","output":"ok"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_ws_invalid_encrypted_content_function_call_output_ok", result.RequestID) + require.Nil(t, upstream.lastReq, "function_call_output + 
invalid_encrypted_content 不应回退 HTTP") + require.Equal(t, int32(2), wsAttempts.Load(), "应只做一次保锚点的清洗后重试") + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) + wsRequestMu.Unlock() + require.Len(t, requests, 2) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id") + require.True(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "function_call_output 恢复重试不应移除 previous_response_id") + require.False(t, gjson.GetBytes(requests[1], `input.0.encrypted_content`).Exists(), "恢复重试应移除 reasoning encrypted_content") + require.Equal(t, "function_call_output", gjson.GetBytes(requests[1], `input.0.type`).String(), "清洗后应保留 function_call_output 作为首个输入项") + require.Equal(t, "call_123", gjson.GetBytes(requests[1], `input.0.call_id`).String()) + require.Equal(t, "ok", gjson.GetBytes(requests[1], `input.0.output`).String()) + require.Equal(t, "resp_prev_function_call", gjson.GetBytes(requests[1], "previous_response_id").String()) +} diff --git a/backend/internal/service/openai_ws_protocol_resolver.go b/backend/internal/service/openai_ws_protocol_resolver.go index 368643be..7266759c 100644 --- a/backend/internal/service/openai_ws_protocol_resolver.go +++ b/backend/internal/service/openai_ws_protocol_resolver.go @@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt switch mode { case OpenAIWSIngressModeOff: return openAIWSHTTPDecision("account_mode_off") - case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough: // continue + case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + // 历史值兼容:按 ctx_pool 处理。 + mode = OpenAIWSIngressModeCtxPool default: return openAIWSHTTPDecision("account_mode_off") } diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go index 5be76e28..4d5dc5f1 100644 --- 
a/backend/internal/service/openai_ws_protocol_resolver_test.go +++ b/backend/internal/service/openai_ws_protocol_resolver_test.go @@ -143,21 +143,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true - cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 1, Extra: map[string]any{ - "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, }, } - t.Run("dedicated mode routes to ws v2", func(t *testing.T) { + t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) { decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account) require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) - require.Equal(t, "ws_v2_mode_dedicated", decision.Reason) + require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason) }) t.Run("off mode routes to http", func(t *testing.T) { @@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { require.Equal(t, "account_mode_off", decision.Reason) }) - t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) { + t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) { legacyAccount := &Account{ Platform: PlatformOpenAI, Type: AccountTypeAPIKey, @@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { } decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount) require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) - require.Equal(t, "ws_v2_mode_shared", decision.Reason) + require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason) + }) + + t.Run("passthrough mode routes to ws 
v2", func(t *testing.T) { + passthroughAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_mode_passthrough", decision.Reason) }) t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) { @@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{ - "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared, + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, }, } decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency) diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go new file mode 100644 index 00000000..f5c79923 --- /dev/null +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -0,0 +1,511 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" +) + +type openAIWSRateLimitSignalRepo struct { + stubOpenAIAccountRepo + rateLimitCalls []time.Time + updateExtra []map[string]any +} + +type openAICodexSnapshotAsyncRepo struct { + stubOpenAIAccountRepo + updateExtraCh chan map[string]any + rateLimitCh chan time.Time +} + +type openAICodexExtraListRepo struct { + stubOpenAIAccountRepo + rateLimitCh chan time.Time +} + +func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ 
int64, resetAt time.Time) error { + r.rateLimitCalls = append(r.rateLimitCalls, resetAt) + return nil +} + +func (r *openAIWSRateLimitSignalRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtra = append(r.updateExtra, copied) + return nil +} + +func (r *openAICodexSnapshotAsyncRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + +func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + if r.updateExtraCh != nil { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtraCh <- copied + } + return nil +} + +func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + +func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { + _ = platform + _ = accountType + _ = status + _ = search + _ = groupID + return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil +} + +func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + resetAt := time.Now().Add(2 * time.Hour).Unix() + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { _ = conn.Close() }() + + 
var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "rate_limit_exceeded", + "type": "usage_limit_reached", + "message": "The usage limit has been reached", + "resets_at": resetAt, + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)), + }, + } + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + + account := Account{ + ID: 501, + Name: "openai-ws-rate-limit-event", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := &OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, &account, body) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, 
http.StatusTooManyRequests, rec.Code) + require.Nil(t, upstream.lastReq, "WS 限流 error event 不应回退到同账号 HTTP") + require.Len(t, repo.rateLimitCalls, 1) + require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) +} + +func TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("x-codex-primary-used-percent", "100") + w.Header().Set("x-codex-primary-reset-after-seconds", "7200") + w.Header().Set("x-codex-primary-window-minutes", "10080") + w.Header().Set("x-codex-secondary-used-percent", "3") + w.Header().Set("x-codex-secondary-reset-after-seconds", "1800") + w.Header().Set("x-codex-secondary-window-minutes", "300") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":{"type":"rate_limit_exceeded","message":"rate limited"}}`)) + })) + defer server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)), + }, + } + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + + account := Account{ + ID: 502, + Name: "openai-ws-rate-limit-handshake", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: 
stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := &OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, &account, body) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Nil(t, upstream.lastReq, "WS 握手 429 不应回退到同账号 HTTP") + require.Len(t, repo.rateLimitCalls, 1) + require.NotEmpty(t, repo.updateExtra, "握手 429 的 x-codex 头应立即落库") + require.Contains(t, repo.updateExtra[0], "codex_usage_updated_at") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageLimitPersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + resetAt := time.Now().Add(90 * time.Minute).Unix() + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached","resets_at":PLACEHOLDER}}`), + }, + } + captureConn.events[0] = []byte(strings.ReplaceAll(string(captureConn.events[0]), "PLACEHOLDER", strconv.FormatInt(resetAt, 10))) + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + 
pool.setClientDialerForTest(captureDialer) + + account := Account{ + ID: 503, + Name: "openai-ingress-rate-limit", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := &OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover}) + if err != nil { + serverErrCh <- err + return + } + defer func() { _ = conn.CloseNow() }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- io.ErrUnexpectedEOF + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, &account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, 
"ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + select { + case serverErr := <-serverErrCh: + require.Error(t, serverErr) + require.Len(t, repo.rateLimitCalls, 1) + require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } +} + +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRateLimit(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), + } + svc := &OpenAIGatewayService{accountRepo: repo} + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(100), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(12), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + before := time.Now() + svc.updateCodexUsageSnapshot(context.Background(), 601, snapshot) + + select { + case updates := <-repo.updateExtraCh: + require.Equal(t, 100.0, updates["codex_7d_used_percent"]) + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 快照落库超时") + } + + select { + case resetAt := <-repo.rateLimitCh: + require.WithinDuration(t, before.Add(time.Hour), resetAt, 2*time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 100% 自动切换限流超时") + } +} + +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesNotSetRateLimit(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), 
+ } + svc := &OpenAIGatewayService{accountRepo: repo} + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(94), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(22), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + svc.updateCodexUsageSnapshot(context.Background(), 602, snapshot) + + select { + case <-repo.updateExtraCh: + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 快照落库超时") + } + + select { + case resetAt := <-repo.rateLimitCh: + t.Fatalf("unexpected rate limit reset at: %v", resetAt) + case <-time.After(200 * time.Millisecond): + } +} + +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 2), + rateLimitCh: make(chan time.Time, 2), + } + svc := &OpenAIGatewayService{ + accountRepo: repo, + codexSnapshotThrottle: newAccountWriteThrottle(time.Hour), + } + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(94), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(22), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + + svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot) + svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot) + + select { + case <-repo.updateExtraCh: + case <-time.After(2 * time.Second): + t.Fatal("等待第一次 codex 快照落库超时") + } + + select { + case updates := <-repo.updateExtraCh: + t.Fatalf("unexpected second codex snapshot write: %v", updates) + case <-time.After(200 * time.Millisecond): + } +} + +func ptrFloat64WS(v float64) *float64 { return &v } +func ptrIntWS(v int) *int { return &v } + +func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit(t *testing.T) { + resetAt := time.Now().Add(6 * 24 
* time.Hour) + account := Account{ + ID: 701, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339), + }, + } + repo := &openAICodexExtraListRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, rateLimitCh: make(chan time.Time, 1)} + svc := &OpenAIGatewayService{accountRepo: repo} + + fresh, err := svc.getSchedulableAccount(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, fresh) + require.NotNil(t, fresh.RateLimitResetAt) + require.WithinDuration(t, resetAt.UTC(), *fresh.RateLimitResetAt, time.Second) + select { + case persisted := <-repo.rateLimitCh: + require.WithinDuration(t, resetAt.UTC(), persisted, time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待旧快照补写限流状态超时") + } +} + +func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(t *testing.T) { + resetAt := time.Now().Add(4 * 24 * time.Hour) + repo := &openAICodexExtraListRepo{ + stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{{ + ID: 702, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339), + }, + }}}, + rateLimitCh: make(chan time.Time, 1), + } + svc := &adminServiceImpl{accountRepo: repo} + + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0) + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Len(t, accounts, 1) + require.NotNil(t, accounts[0].RateLimitResetAt) + require.WithinDuration(t, resetAt.UTC(), *accounts[0].RateLimitResetAt, time.Second) + select { + case persisted := <-repo.rateLimitCh: + require.WithinDuration(t, resetAt.UTC(), persisted, 
time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待列表补写限流状态超时") + } +} + +func TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429(t *testing.T) { + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("", "usage_limit_reached")) + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("rate_limit_exceeded", "")) +} diff --git a/backend/internal/service/openai_ws_v2/caddy_adapter.go b/backend/internal/service/openai_ws_v2/caddy_adapter.go new file mode 100644 index 00000000..1fecc231 --- /dev/null +++ b/backend/internal/service/openai_ws_v2/caddy_adapter.go @@ -0,0 +1,24 @@ +package openai_ws_v2 + +import ( + "context" +) + +// runCaddyStyleRelay 采用 Caddy reverseproxy 的双向隧道思想: +// 连接建立后并发复制两个方向,任一方向退出触发收敛关闭。 +// +// Reference: +// - Project: caddyserver/caddy (Apache-2.0) +// - Commit: f283062d37c50627d53ca682ebae2ce219b35515 +// - Files: +// - modules/caddyhttp/reverseproxy/streaming.go +// - modules/caddyhttp/reverseproxy/reverseproxy.go +func runCaddyStyleRelay( + ctx context.Context, + clientConn FrameConn, + upstreamConn FrameConn, + firstClientMessage []byte, + options RelayOptions, +) (RelayResult, *RelayExit) { + return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options) +} diff --git a/backend/internal/service/openai_ws_v2/entry.go b/backend/internal/service/openai_ws_v2/entry.go new file mode 100644 index 00000000..176298fe --- /dev/null +++ b/backend/internal/service/openai_ws_v2/entry.go @@ -0,0 +1,23 @@ +package openai_ws_v2 + +import "context" + +// EntryInput 是 passthrough v2 数据面的入口参数。 +type EntryInput struct { + Ctx context.Context + ClientConn FrameConn + UpstreamConn FrameConn + FirstClientMessage []byte + Options RelayOptions +} + +// RunEntry 是 openai_ws_v2 包对外的统一入口。 +func RunEntry(input EntryInput) (RelayResult, *RelayExit) { + return runCaddyStyleRelay( + input.Ctx, + input.ClientConn, + input.UpstreamConn, + input.FirstClientMessage, + input.Options, + ) +} diff 
--git a/backend/internal/service/openai_ws_v2/metrics.go b/backend/internal/service/openai_ws_v2/metrics.go new file mode 100644 index 00000000..3708befd --- /dev/null +++ b/backend/internal/service/openai_ws_v2/metrics.go @@ -0,0 +1,29 @@ +package openai_ws_v2 + +import ( + "sync/atomic" +) + +// MetricsSnapshot 是 OpenAI WS v2 passthrough 路径的轻量运行时指标快照。 +type MetricsSnapshot struct { + SemanticMutationTotal int64 `json:"semantic_mutation_total"` + UsageParseFailureTotal int64 `json:"usage_parse_failure_total"` +} + +var ( + // passthrough 路径默认不会做语义改写,该计数通常应保持为 0(保留用于未来防御性校验)。 + passthroughSemanticMutationTotal atomic.Int64 + passthroughUsageParseFailureTotal atomic.Int64 +) + +func recordUsageParseFailure() { + passthroughUsageParseFailureTotal.Add(1) +} + +// SnapshotMetrics 返回当前 passthrough 指标快照。 +func SnapshotMetrics() MetricsSnapshot { + return MetricsSnapshot{ + SemanticMutationTotal: passthroughSemanticMutationTotal.Load(), + UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(), + } +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay.go b/backend/internal/service/openai_ws_v2/passthrough_relay.go new file mode 100644 index 00000000..af8ee195 --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay.go @@ -0,0 +1,807 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "strings" + "sync/atomic" + "time" + + coderws "github.com/coder/websocket" + "github.com/tidwall/gjson" +) + +type FrameConn interface { + ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) + WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error + Close() error +} + +type Usage struct { + InputTokens int + OutputTokens int + CacheCreationInputTokens int + CacheReadInputTokens int +} + +type RelayResult struct { + RequestModel string + Usage Usage + RequestID string + TerminalEventType string + FirstTokenMs *int + Duration time.Duration + ClientToUpstreamFrames 
int64 + UpstreamToClientFrames int64 + DroppedDownstreamFrames int64 +} + +type RelayTurnResult struct { + RequestModel string + Usage Usage + RequestID string + TerminalEventType string + Duration time.Duration + FirstTokenMs *int +} + +type RelayExit struct { + Stage string + Err error + WroteDownstream bool +} + +type RelayOptions struct { + WriteTimeout time.Duration + IdleTimeout time.Duration + UpstreamDrainTimeout time.Duration + FirstMessageType coderws.MessageType + OnUsageParseFailure func(eventType string, usageRaw string) + OnTurnComplete func(turn RelayTurnResult) + OnTrace func(event RelayTraceEvent) + Now func() time.Time +} + +type RelayTraceEvent struct { + Stage string + Direction string + MessageType string + PayloadBytes int + Graceful bool + WroteDownstream bool + Error string +} + +type relayState struct { + usage Usage + requestModel string + lastResponseID string + terminalEventType string + firstTokenMs *int + turnTimingByID map[string]*relayTurnTiming +} + +type relayExitSignal struct { + stage string + err error + graceful bool + wroteDownstream bool +} + +type observedUpstreamEvent struct { + terminal bool + eventType string + responseID string + usage Usage + duration time.Duration + firstToken *int +} + +type relayTurnTiming struct { + startAt time.Time + firstTokenMs *int +} + +func Relay( + ctx context.Context, + clientConn FrameConn, + upstreamConn FrameConn, + firstClientMessage []byte, + options RelayOptions, +) (RelayResult, *RelayExit) { + result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())} + if clientConn == nil || upstreamConn == nil { + return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")} + } + if ctx == nil { + ctx = context.Background() + } + + nowFn := options.Now + if nowFn == nil { + nowFn = time.Now + } + writeTimeout := options.WriteTimeout + if writeTimeout <= 0 { + writeTimeout = 2 * time.Minute + } + drainTimeout := 
options.UpstreamDrainTimeout + if drainTimeout <= 0 { + drainTimeout = 1200 * time.Millisecond + } + firstMessageType := options.FirstMessageType + if firstMessageType != coderws.MessageBinary { + firstMessageType = coderws.MessageText + } + startAt := nowFn() + state := &relayState{requestModel: result.RequestModel} + onTrace := options.OnTrace + + relayCtx, relayCancel := context.WithCancel(ctx) + defer relayCancel() + + lastActivity := atomic.Int64{} + lastActivity.Store(nowFn().UnixNano()) + markActivity := func() { + lastActivity.Store(nowFn().UnixNano()) + } + + writeUpstream := func(msgType coderws.MessageType, payload []byte) error { + writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout) + defer cancel() + return upstreamConn.WriteFrame(writeCtx, msgType, payload) + } + writeClient := func(msgType coderws.MessageType, payload []byte) error { + writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout) + defer cancel() + return clientConn.WriteFrame(writeCtx, msgType, payload) + } + + clientToUpstreamFrames := &atomic.Int64{} + upstreamToClientFrames := &atomic.Int64{} + droppedDownstreamFrames := &atomic.Int64{} + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_start", + PayloadBytes: len(firstClientMessage), + MessageType: relayMessageTypeString(firstMessageType), + }) + + if err := writeUpstream(firstMessageType, firstClientMessage); err != nil { + result.Duration = nowFn().Sub(startAt) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_failed", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + Error: err.Error(), + }) + return result, &RelayExit{Stage: "write_upstream", Err: err} + } + clientToUpstreamFrames.Add(1) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_ok", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + 
}) + markActivity() + + exitCh := make(chan relayExitSignal, 3) + dropDownstreamWrites := atomic.Bool{} + go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh) + go runUpstreamToClient( + relayCtx, + upstreamConn, + writeClient, + startAt, + nowFn, + state, + options.OnUsageParseFailure, + options.OnTurnComplete, + &dropDownstreamWrites, + upstreamToClientFrames, + droppedDownstreamFrames, + markActivity, + onTrace, + exitCh, + ) + go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh) + + firstExit := <-exitCh + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "first_exit", + Direction: relayDirectionFromStage(firstExit.stage), + Graceful: firstExit.graceful, + WroteDownstream: firstExit.wroteDownstream, + Error: relayErrorString(firstExit.err), + }) + combinedWroteDownstream := firstExit.wroteDownstream + secondExit := relayExitSignal{graceful: true} + hasSecondExit := false + + // 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。 + if firstExit.stage == "read_client" && firstExit.graceful { + dropDownstreamWrites.Store(true) + secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout) + } else { + relayCancel() + _ = upstreamConn.Close() + secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond) + } + if hasSecondExit { + combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "second_exit", + Direction: relayDirectionFromStage(secondExit.stage), + Graceful: secondExit.graceful, + WroteDownstream: secondExit.wroteDownstream, + Error: relayErrorString(secondExit.err), + }) + } + + relayCancel() + _ = upstreamConn.Close() + + enrichResult(&result, state, nowFn().Sub(startAt)) + result.ClientToUpstreamFrames = clientToUpstreamFrames.Load() + result.UpstreamToClientFrames = upstreamToClientFrames.Load() + result.DroppedDownstreamFrames = droppedDownstreamFrames.Load() + if 
firstExit.stage == "read_client" && firstExit.graceful { + stage := "client_disconnected" + exitErr := firstExit.err + if hasSecondExit && !secondExit.graceful { + stage = secondExit.stage + exitErr = secondExit.err + } + if exitErr == nil { + exitErr = io.EOF + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(exitErr), + }) + return result, &RelayExit{ + Stage: stage, + Err: exitErr, + WroteDownstream: combinedWroteDownstream, + } + } + if firstExit.graceful && (!hasSecondExit || secondExit.graceful) { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_complete", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + _ = clientConn.Close() + return result, nil + } + if !firstExit.graceful { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(firstExit.stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(firstExit.err), + }) + return result, &RelayExit{ + Stage: firstExit.stage, + Err: firstExit.err, + WroteDownstream: combinedWroteDownstream, + } + } + if hasSecondExit && !secondExit.graceful { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(secondExit.stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(secondExit.err), + }) + return result, &RelayExit{ + Stage: secondExit.stage, + Err: secondExit.err, + WroteDownstream: combinedWroteDownstream, + } + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_complete", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + _ = clientConn.Close() + return result, nil +} + +func runClientToUpstream( + ctx context.Context, + clientConn FrameConn, + writeUpstream func(msgType coderws.MessageType, payload []byte) error, + markActivity 
func(), + forwardedFrames *atomic.Int64, + onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + for { + msgType, payload, err := clientConn.ReadFrame(ctx) + if err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "read_client_failed", + Direction: "client_to_upstream", + Error: err.Error(), + Graceful: isDisconnectError(err), + }) + exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)} + return + } + markActivity() + if err := writeUpstream(msgType, payload); err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_upstream_failed", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + Error: err.Error(), + }) + exitCh <- relayExitSignal{stage: "write_upstream", err: err} + return + } + if forwardedFrames != nil { + forwardedFrames.Add(1) + } + markActivity() + } +} + +func runUpstreamToClient( + ctx context.Context, + upstreamConn FrameConn, + writeClient func(msgType coderws.MessageType, payload []byte) error, + startAt time.Time, + nowFn func() time.Time, + state *relayState, + onUsageParseFailure func(eventType string, usageRaw string), + onTurnComplete func(turn RelayTurnResult), + dropDownstreamWrites *atomic.Bool, + forwardedFrames *atomic.Int64, + droppedFrames *atomic.Int64, + markActivity func(), + onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + wroteDownstream := false + for { + msgType, payload, err := upstreamConn.ReadFrame(ctx) + if err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "read_upstream_failed", + Direction: "upstream_to_client", + Error: err.Error(), + Graceful: isDisconnectError(err), + WroteDownstream: wroteDownstream, + }) + exitCh <- relayExitSignal{ + stage: "read_upstream", + err: err, + graceful: isDisconnectError(err), + wroteDownstream: wroteDownstream, + } + return + } + markActivity() + observedEvent := observedUpstreamEvent{} + switch 
msgType { + case coderws.MessageText: + observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure) + case coderws.MessageBinary: + // binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。 + } + emitTurnComplete(onTurnComplete, state, observedEvent) + if dropDownstreamWrites != nil && dropDownstreamWrites.Load() { + if droppedFrames != nil { + droppedFrames.Add(1) + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "drop_downstream_frame", + Direction: "upstream_to_client", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + WroteDownstream: wroteDownstream, + }) + if observedEvent.terminal { + exitCh <- relayExitSignal{ + stage: "drain_terminal", + graceful: true, + wroteDownstream: wroteDownstream, + } + return + } + markActivity() + continue + } + if err := writeClient(msgType, payload); err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_client_failed", + Direction: "upstream_to_client", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + WroteDownstream: wroteDownstream, + Error: err.Error(), + }) + exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream} + return + } + wroteDownstream = true + if forwardedFrames != nil { + forwardedFrames.Add(1) + } + markActivity() + } +} + +func runIdleWatchdog( + ctx context.Context, + nowFn func() time.Time, + idleTimeout time.Duration, + lastActivity *atomic.Int64, + onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + if idleTimeout <= 0 { + return + } + checkInterval := minDuration(idleTimeout/4, 5*time.Second) + if checkInterval < time.Second { + checkInterval = time.Second + } + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + last := time.Unix(0, lastActivity.Load()) + if nowFn().Sub(last) < idleTimeout { + continue + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: 
"idle_timeout_triggered", + Direction: "watchdog", + Error: context.DeadlineExceeded.Error(), + }) + exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded} + return + } + } +} + +func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) { + if onTrace == nil { + return + } + onTrace(event) +} + +func relayMessageTypeString(msgType coderws.MessageType) string { + switch msgType { + case coderws.MessageText: + return "text" + case coderws.MessageBinary: + return "binary" + default: + return "unknown(" + strconv.Itoa(int(msgType)) + ")" + } +} + +func relayDirectionFromStage(stage string) string { + switch stage { + case "read_client", "write_upstream": + return "client_to_upstream" + case "read_upstream", "write_client", "drain_terminal": + return "upstream_to_client" + case "idle_timeout": + return "watchdog" + default: + return "" + } +} + +func relayErrorString(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func observeUpstreamMessage( + state *relayState, + message []byte, + startAt time.Time, + nowFn func() time.Time, + onUsageParseFailure func(eventType string, usageRaw string), +) observedUpstreamEvent { + if state == nil || len(message) == 0 { + return observedUpstreamEvent{} + } + values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id") + eventType := strings.TrimSpace(values[0].String()) + if eventType == "" { + return observedUpstreamEvent{} + } + responseID := strings.TrimSpace(values[1].String()) + if responseID == "" { + responseID = strings.TrimSpace(values[2].String()) + } + // 仅 terminal 事件兜底读取顶层 id,避免把 event_id 当成 response_id 关联到 turn。 + if responseID == "" && isTerminalEvent(eventType) { + responseID = strings.TrimSpace(values[3].String()) + } + now := nowFn() + + if state.firstTokenMs == nil && isTokenEvent(eventType) { + ms := int(now.Sub(startAt).Milliseconds()) + if ms >= 0 { + state.firstTokenMs = &ms + } + } + parsedUsage := 
parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure) + observed := observedUpstreamEvent{ + eventType: eventType, + responseID: responseID, + usage: parsedUsage, + } + if responseID != "" { + turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now) + if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) { + ms := int(now.Sub(turnTiming.startAt).Milliseconds()) + if ms >= 0 { + turnTiming.firstTokenMs = &ms + } + } + } + if !isTerminalEvent(eventType) { + return observed + } + observed.terminal = true + state.terminalEventType = eventType + if responseID != "" { + state.lastResponseID = responseID + if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok { + duration := now.Sub(turnTiming.startAt) + if duration < 0 { + duration = 0 + } + observed.duration = duration + observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs) + } + } + return observed +} + +func emitTurnComplete( + onTurnComplete func(turn RelayTurnResult), + state *relayState, + observed observedUpstreamEvent, +) { + if onTurnComplete == nil || !observed.terminal { + return + } + responseID := strings.TrimSpace(observed.responseID) + if responseID == "" { + return + } + requestModel := "" + if state != nil { + requestModel = state.requestModel + } + onTurnComplete(RelayTurnResult{ + RequestModel: requestModel, + Usage: observed.usage, + RequestID: responseID, + TerminalEventType: observed.eventType, + Duration: observed.duration, + FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken), + }) +} + +func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming { + if state == nil { + return nil + } + if state.turnTimingByID == nil { + state.turnTimingByID = make(map[string]*relayTurnTiming, 8) + } + timing, ok := state.turnTimingByID[responseID] + if !ok || timing == nil || timing.startAt.IsZero() { + timing = &relayTurnTiming{startAt: now} + 
state.turnTimingByID[responseID] = timing + return timing + } + return timing +} + +func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) { + if state == nil || state.turnTimingByID == nil { + return relayTurnTiming{}, false + } + timing, ok := state.turnTimingByID[responseID] + if !ok || timing == nil { + return relayTurnTiming{}, false + } + delete(state.turnTimingByID, responseID) + return *timing, true +} + +func openAIWSRelayCloneIntPtr(v *int) *int { + if v == nil { + return nil + } + cloned := *v + return &cloned +} + +func parseUsageAndAccumulate( + state *relayState, + message []byte, + eventType string, + onParseFailure func(eventType string, usageRaw string), +) Usage { + if state == nil || len(message) == 0 || !shouldParseUsage(eventType) { + return Usage{} + } + usageResult := gjson.GetBytes(message, "response.usage") + if !usageResult.Exists() { + return Usage{} + } + usageRaw := strings.TrimSpace(usageResult.Raw) + if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") { + recordUsageParseFailure() + if onParseFailure != nil { + onParseFailure(eventType, usageRaw) + } + return Usage{} + } + + inputResult := gjson.GetBytes(message, "response.usage.input_tokens") + outputResult := gjson.GetBytes(message, "response.usage.output_tokens") + cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens") + + inputTokens, inputOK := parseUsageIntField(inputResult, true) + outputTokens, outputOK := parseUsageIntField(outputResult, true) + cachedTokens, cachedOK := parseUsageIntField(cachedResult, false) + if !inputOK || !outputOK || !cachedOK { + recordUsageParseFailure() + if onParseFailure != nil { + onParseFailure(eventType, usageRaw) + } + // 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。 + return Usage{} + } + parsedUsage := Usage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheReadInputTokens: cachedTokens, + } + + state.usage.InputTokens += parsedUsage.InputTokens + 
state.usage.OutputTokens += parsedUsage.OutputTokens + state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens + return parsedUsage +} + +func parseUsageIntField(value gjson.Result, required bool) (int, bool) { + if !value.Exists() { + return 0, !required + } + if value.Type != gjson.Number { + return 0, false + } + return int(value.Int()), true +} + +func enrichResult(result *RelayResult, state *relayState, duration time.Duration) { + if result == nil { + return + } + result.Duration = duration + if state == nil { + return + } + result.RequestModel = state.requestModel + result.Usage = state.usage + result.RequestID = state.lastResponseID + result.TerminalEventType = state.terminalEventType + result.FirstTokenMs = state.firstTokenMs +} + +func isDisconnectError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + return true + } + switch coderws.CloseStatus(err) { + case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: + return true + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + if message == "" { + return false + } + return strings.Contains(message, "failed to read frame header: eof") || + strings.Contains(message, "unexpected eof") || + strings.Contains(message, "use of closed network connection") || + strings.Contains(message, "connection reset by peer") || + strings.Contains(message, "broken pipe") +} + +func isTerminalEvent(eventType string) bool { + switch eventType { + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": + return true + default: + return false + } +} + +func shouldParseUsage(eventType string) bool { + switch eventType { + case "response.completed", "response.done", "response.failed": + return true + default: + return false + } +} + +func isTokenEvent(eventType string) 
bool { + if eventType == "" { + return false + } + switch eventType { + case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": + return false + } + if strings.Contains(eventType, ".delta") { + return true + } + if strings.HasPrefix(eventType, "response.output_text") { + return true + } + if strings.HasPrefix(eventType, "response.output") { + return true + } + return eventType == "response.completed" || eventType == "response.done" +} + +func minDuration(a, b time.Duration) time.Duration { + if a <= 0 { + return b + } + if b <= 0 { + return a + } + if a < b { + return a + } + return b +} + +func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) { + if timeout <= 0 { + timeout = 200 * time.Millisecond + } + select { + case sig := <-exitCh: + return sig, true + case <-time.After(timeout): + return relayExitSignal{}, false + } +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go new file mode 100644 index 00000000..123e10ce --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go @@ -0,0 +1,432 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "net" + "sync/atomic" + "testing" + "time" + + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestRunEntry_DelegatesRelay(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + result, relayExit := RunEntry(EntryInput{ + Ctx: context.Background(), + ClientConn: clientConn, + UpstreamConn: upstreamConn, + 
FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`), + }) + require.Nil(t, relayExit) + require.Equal(t, "resp_entry", result.RequestID) +} + +func TestRunClientToUpstream_ErrorPaths(t *testing.T) { + t.Parallel() + + t.Run("read client eof", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn(nil, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + func() {}, + nil, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "read_client", sig.stage) + require.True(t, sig.graceful) + }) + + t.Run("write upstream failed", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") }, + func() {}, + nil, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "write_upstream", sig.stage) + require.False(t, sig.graceful) + }) + + t.Run("forwarded counter and trace callback", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + forwarded := &atomic.Int64{} + traces := make([]RelayTraceEvent, 0, 2) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + func() {}, + forwarded, + func(event RelayTraceEvent) { + traces = append(traces, event) + }, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "read_client", sig.stage) + require.Equal(t, int64(1), forwarded.Load()) + require.NotEmpty(t, traces) + }) +} + +func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) { + t.Parallel() + + t.Run("read upstream eof", func(t *testing.T) { + 
t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(false) + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn(nil, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + nil, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "read_upstream", sig.stage) + require.True(t, sig.graceful) + }) + + t.Run("write client failed", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(false) + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + nil, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "write_client", sig.stage) + }) + + t.Run("drop downstream and stop on terminal", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(true) + dropped := &atomic.Int64{} + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + dropped, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "drain_terminal", sig.stage) + require.True(t, sig.graceful) + require.Equal(t, int64(1), dropped.Load()) + }) +} + +func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) { + 
t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + lastActivity := &atomic.Int64{} + lastActivity.Store(time.Now().UnixNano()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh) + select { + case <-exitCh: + t.Fatal("unexpected idle timeout signal") + case <-time.After(200 * time.Millisecond): + } +} + +func TestHelperFunctionsCoverage(t *testing.T) { + t.Parallel() + + require.Equal(t, "text", relayMessageTypeString(coderws.MessageText)) + require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary)) + require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(") + + require.Equal(t, "", relayErrorString(nil)) + require.Equal(t, "x", relayErrorString(errors.New("x"))) + + require.True(t, isDisconnectError(io.EOF)) + require.True(t, isDisconnectError(net.ErrClosed)) + require.True(t, isDisconnectError(context.Canceled)) + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway})) + require.True(t, isDisconnectError(errors.New("broken pipe"))) + require.False(t, isDisconnectError(errors.New("unrelated"))) + + require.True(t, isTokenEvent("response.output_text.delta")) + require.True(t, isTokenEvent("response.output_audio.delta")) + require.True(t, isTokenEvent("response.completed")) + require.False(t, isTokenEvent("")) + require.False(t, isTokenEvent("response.created")) + + require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second)) + require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second)) + require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second)) + require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0)) + + ch := make(chan relayExitSignal, 1) + ch <- relayExitSignal{stage: "ok"} + sig, ok := waitRelayExit(ch, 10*time.Millisecond) + require.True(t, ok) + require.Equal(t, "ok", sig.stage) + ch <- relayExitSignal{stage: "ok2"} + sig, ok = waitRelayExit(ch, 0) + 
require.True(t, ok) + require.Equal(t, "ok2", sig.stage) + _, ok = waitRelayExit(ch, 10*time.Millisecond) + require.False(t, ok) + + n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true) + require.True(t, ok) + require.Equal(t, 3, n) + _, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true) + require.False(t, ok) + n, ok = parseUsageIntField(gjson.Result{}, false) + require.True(t, ok) + require.Equal(t, 0, n) + _, ok = parseUsageIntField(gjson.Result{}, true) + require.False(t, ok) +} + +func TestParseUsageAndEnrichCoverage(t *testing.T) { + t.Parallel() + + state := &relayState{} + parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil) + require.Equal(t, 0, state.usage.InputTokens) + + parseUsageAndAccumulate( + state, + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`), + "response.completed", + nil, + ) + require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage") + require.Equal(t, 0, state.usage.OutputTokens) + require.Equal(t, 0, state.usage.CacheReadInputTokens) + + parseUsageAndAccumulate( + state, + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`), + "response.completed", + nil, + ) + require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage") + require.Equal(t, 0, state.usage.OutputTokens) + require.Equal(t, 0, state.usage.CacheReadInputTokens) + + parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil) + require.Equal(t, 2, state.usage.InputTokens) + require.Equal(t, 1, state.usage.OutputTokens) + require.Equal(t, 1, state.usage.CacheReadInputTokens) + + result := &RelayResult{} + enrichResult(result, state, 5*time.Millisecond) + 
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens) + require.Equal(t, 5*time.Millisecond, result.Duration) + parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil) + require.Equal(t, 2, state.usage.InputTokens) + enrichResult(nil, state, 0) +} + +func TestEmitTurnCompleteCoverage(t *testing.T) { + t.Parallel() + + // 非 terminal 事件不应触发。 + called := 0 + emitTurnComplete(func(turn RelayTurnResult) { + called++ + }, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{ + terminal: false, + eventType: "response.output_text.delta", + responseID: "resp_ignored", + usage: Usage{InputTokens: 1}, + }) + require.Equal(t, 0, called) + + // 缺少 response_id 时不应触发。 + emitTurnComplete(func(turn RelayTurnResult) { + called++ + }, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{ + terminal: true, + eventType: "response.completed", + }) + require.Equal(t, 0, called) + + // terminal 且 response_id 存在,应该触发;state=nil 时 model 为空串。 + var got RelayTurnResult + emitTurnComplete(func(turn RelayTurnResult) { + called++ + got = turn + }, nil, observedUpstreamEvent{ + terminal: true, + eventType: "response.completed", + responseID: "resp_emit", + usage: Usage{InputTokens: 2, OutputTokens: 3}, + }) + require.Equal(t, 1, called) + require.Equal(t, "resp_emit", got.RequestID) + require.Equal(t, "response.completed", got.TerminalEventType) + require.Equal(t, 2, got.Usage.InputTokens) + require.Equal(t, 3, got.Usage.OutputTokens) + require.Equal(t, "", got.RequestModel) +} + +func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) { + t.Parallel() + + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure})) + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd})) + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure})) + require.True(t, 
isDisconnectError(errors.New("connection reset by peer"))) + require.False(t, isDisconnectError(errors.New(" "))) +} + +func TestIsTokenEventCoverageBranches(t *testing.T) { + t.Parallel() + + require.False(t, isTokenEvent("response.in_progress")) + require.False(t, isTokenEvent("response.output_item.added")) + require.True(t, isTokenEvent("response.output_audio.delta")) + require.True(t, isTokenEvent("response.output")) + require.True(t, isTokenEvent("response.done")) +} + +func TestRelayTurnTimingHelpersCoverage(t *testing.T) { + t.Parallel() + + now := time.Unix(100, 0) + // nil state + require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now)) + _, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil") + require.False(t, ok) + + state := &relayState{} + timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now) + require.NotNil(t, timing) + require.Equal(t, now, timing.startAt) + + // 再次获取返回同一条 timing + timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second)) + require.NotNil(t, timing2) + require.Equal(t, now, timing2.startAt) + + // 删除存在键 + deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a") + require.True(t, ok) + require.Equal(t, now, deleted.startAt) + + // 删除不存在键 + _, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a") + require.False(t, ok) +} + +func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) { + t.Parallel() + + state := &relayState{requestModel: "gpt-5"} + startAt := time.Unix(0, 0) + now := startAt + nowFn := func() time.Time { + now = now.Add(5 * time.Millisecond) + return now + } + + // 非 terminal:仅有顶层 id,不应把 event id 当成 response_id。 + observed := observeUpstreamMessage( + state, + []byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`), + startAt, + nowFn, + nil, + ) + require.False(t, observed.terminal) + require.Equal(t, "", observed.responseID) + + // terminal:允许兜底用顶层 id(用于兼容少数字段变体)。 + observed = observeUpstreamMessage( + state, + 
[]byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`), + startAt, + nowFn, + nil, + ) + require.True(t, observed.terminal) + require.Equal(t, "resp_fallback", observed.responseID) +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go new file mode 100644 index 00000000..ff9b7311 --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go @@ -0,0 +1,752 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "sync" + "sync/atomic" + "testing" + "time" + + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" +) + +type passthroughTestFrame struct { + msgType coderws.MessageType + payload []byte +} + +type passthroughTestFrameConn struct { + mu sync.Mutex + writes []passthroughTestFrame + readCh chan passthroughTestFrame + once sync.Once +} + +type delayedReadFrameConn struct { + base FrameConn + firstDelay time.Duration + once sync.Once +} + +type closeSpyFrameConn struct { + closeCalls atomic.Int32 +} + +func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn { + c := &passthroughTestFrameConn{ + readCh: make(chan passthroughTestFrame, len(frames)+1), + } + for _, frame := range frames { + copied := passthroughTestFrame{msgType: frame.msgType, payload: append([]byte(nil), frame.payload...)} + c.readCh <- copied + } + if autoClose { + close(c.readCh) + } + return c +} + +func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return coderws.MessageText, nil, ctx.Err() + case frame, ok := <-c.readCh: + if !ok { + return coderws.MessageText, nil, io.EOF + } + return frame.msgType, append([]byte(nil), frame.payload...), nil + } +} + +func (c *passthroughTestFrameConn) 
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + c.mu.Lock() + defer c.mu.Unlock() + c.writes = append(c.writes, passthroughTestFrame{msgType: msgType, payload: append([]byte(nil), payload...)}) + return nil +} + +func (c *passthroughTestFrameConn) Close() error { + c.once.Do(func() { + defer func() { _ = recover() }() + close(c.readCh) + }) + return nil +} + +func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]passthroughTestFrame, len(c.writes)) + copy(out, c.writes) + return out +} + +func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.base == nil { + return coderws.MessageText, nil, io.EOF + } + c.once.Do(func() { + if c.firstDelay > 0 { + timer := time.NewTimer(c.firstDelay) + defer timer.Stop() + select { + case <-ctx.Done(): + case <-timer.C: + } + } + }) + return c.base.ReadFrame(ctx) +} + +func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.base == nil { + return io.EOF + } + return c.base.WriteFrame(ctx, msgType, payload) +} + +func (c *delayedReadFrameConn) Close() error { + if c == nil || c.base == nil { + return nil + } + return c.base.Close() +} + +func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if ctx == nil { + ctx = context.Background() + } + <-ctx.Done() + return coderws.MessageText, nil, ctx.Err() +} + +func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } +} + +func (c *closeSpyFrameConn) Close() error { + if c != nil { + c.closeCalls.Add(1) + } + 
return nil +} + +func (c *closeSpyFrameConn) CloseCalls() int32 { + if c == nil { + return 0 + } + return c.closeCalls.Load() +} + +func TestRelay_BasicRelayAndUsage(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, "gpt-5.3-codex", result.RequestModel) + require.Equal(t, "resp_123", result.RequestID) + require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 7, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + require.Equal(t, 2, result.Usage.CacheReadInputTokens) + require.NotNil(t, result.FirstTokenMs) + require.Equal(t, int64(1), result.ClientToUpstreamFrames) + require.Equal(t, int64(1), result.UpstreamToClientFrames) + require.Equal(t, int64(0), result.DroppedDownstreamFrames) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType) + require.JSONEq(t, string(firstPayload), string(upstreamWrites[0].payload)) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageText, clientWrites[0].msgType) + require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`, string(clientWrites[0].payload)) 
+} + +func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType) + require.Equal(t, firstPayload, upstreamWrites[0].payload) +} + +func TestRelay_UpstreamDisconnect(t *testing.T) { + t.Parallel() + + // 上游立即关闭(EOF),客户端不发送额外帧 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + // 上游 EOF 属于 disconnect,标记为 graceful + require.Nil(t, relayExit, "上游 EOF 应被视为 graceful disconnect") + require.Equal(t, "gpt-4o", result.RequestModel) +} + +func TestRelay_ClientDisconnect(t *testing.T) { + t.Parallel() + + // 客户端立即关闭(EOF),上游阻塞读取直到 context 取消 + clientConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := 
context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit, "客户端 EOF 应返回可观测的中断状态") + require.Equal(t, "client_disconnected", relayExit.Stage) + require.Equal(t, "gpt-4o", result.RequestModel) +} + +func TestRelay_ClientDisconnect_DrainCapturesLateUsage(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, true) + upstreamBase := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`), + }, + }, true) + upstreamConn := &delayedReadFrameConn{ + base: upstreamBase, + firstDelay: 80 * time.Millisecond, + } + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + UpstreamDrainTimeout: 400 * time.Millisecond, + }) + require.NotNil(t, relayExit) + require.Equal(t, "client_disconnected", relayExit.Stage) + require.Equal(t, "resp_drain", result.RequestID) + require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 6, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 1, result.Usage.CacheReadInputTokens) + require.Equal(t, int64(1), result.ClientToUpstreamFrames) + require.Equal(t, int64(0), result.UpstreamToClientFrames) + require.Equal(t, int64(1), result.DroppedDownstreamFrames) +} + +func TestRelay_IdleTimeout(t *testing.T) { + t.Parallel() + + // 客户端和上游都不发送帧,idle timeout 应触发 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := 
[]byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 使用快进时间来加速 idle timeout + now := time.Now() + callCount := 0 + nowFn := func() time.Time { + callCount++ + // 前几次调用返回正常时间(初始化阶段),之后快进 + if callCount <= 5 { + return now + } + return now.Add(time.Hour) // 快进到超时 + } + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + IdleTimeout: 2 * time.Second, + Now: nowFn, + }) + require.NotNil(t, relayExit, "应因 idle timeout 退出") + require.Equal(t, "idle_timeout", relayExit.Stage) + require.Equal(t, "gpt-4o", result.RequestModel) +} + +func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) { + t.Parallel() + + clientConn := &closeSpyFrameConn{} + upstreamConn := &closeSpyFrameConn{} + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + now := time.Now() + callCount := 0 + nowFn := func() time.Time { + callCount++ + if callCount <= 5 { + return now + } + return now.Add(time.Hour) + } + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + IdleTimeout: 2 * time.Second, + Now: nowFn, + }) + require.NotNil(t, relayExit, "应因 idle timeout 退出") + require.Equal(t, "idle_timeout", relayExit.Stage) + require.Zero(t, clientConn.CloseCalls(), "错误路径不应提前关闭客户端连接,交给上层决定 close code") + require.GreaterOrEqual(t, upstreamConn.CloseCalls(), int32(1)) +} + +func TestRelay_NilConnections(t *testing.T) { + t.Parallel() + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx := context.Background() + + t.Run("nil client conn", func(t *testing.T) { + upstreamConn := newPassthroughTestFrameConn(nil, true) + _, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit) + require.Equal(t, "relay_init", 
relayExit.Stage) + require.Contains(t, relayExit.Err.Error(), "nil") + }) + + t.Run("nil upstream conn", func(t *testing.T) { + clientConn := newPassthroughTestFrameConn(nil, true) + _, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit) + require.Equal(t, "relay_init", relayExit.Stage) + require.Contains(t, relayExit.Err.Error(), "nil") + }) +} + +func TestRelay_MultipleUpstreamMessages(t *testing.T) { + t.Parallel() + + // 上游发送多个事件(delta + completed),验证多帧中继和 usage 聚合 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":"Hello"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, "resp_multi", result.RequestID) + require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 10, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) + require.NotNil(t, result.FirstTokenMs) + + // 验证所有 3 个上游帧都转发给了客户端 + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 3) +} + +func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + 
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + turns := make([]RelayTurnResult, 0, 2) + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + OnTurnComplete: func(turn RelayTurnResult) { + turns = append(turns, turn) + }, + }) + require.Nil(t, relayExit) + require.Len(t, turns, 2) + require.Equal(t, "resp_turn_1", turns[0].RequestID) + require.Equal(t, "response.completed", turns[0].TerminalEventType) + require.Equal(t, 2, turns[0].Usage.InputTokens) + require.Equal(t, 1, turns[0].Usage.OutputTokens) + require.Equal(t, "resp_turn_2", turns[1].RequestID) + require.Equal(t, "response.failed", turns[1].TerminalEventType) + require.Equal(t, 3, turns[1].Usage.InputTokens) + require.Equal(t, 4, turns[1].Usage.OutputTokens) + require.Equal(t, 5, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) +} + +func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := 
[]byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + base := time.Unix(0, 0) + var nowTick atomic.Int64 + nowFn := func() time.Time { + step := nowTick.Add(1) + return base.Add(time.Duration(step) * 5 * time.Millisecond) + } + + var turn RelayTurnResult + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + Now: nowFn, + OnTurnComplete: func(current RelayTurnResult) { + turn = current + }, + }) + require.Nil(t, relayExit) + require.Equal(t, "resp_metric", turn.RequestID) + require.Equal(t, "response.completed", turn.TerminalEventType) + require.NotNil(t, turn.FirstTokenMs) + require.GreaterOrEqual(t, *turn.FirstTokenMs, 0) + require.Greater(t, turn.Duration.Milliseconds(), int64(0)) + require.NotNil(t, result.FirstTokenMs) + require.Greater(t, result.Duration.Milliseconds(), int64(0)) +} + +func TestRelay_BinaryFramePassthrough(t *testing.T) { + t.Parallel() + + // 验证 binary frame 被透传但不进行 usage 解析 + binaryPayload := []byte{0x00, 0x01, 0x02, 0x03} + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageBinary, + payload: binaryPayload, + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + // binary frame 不解析 usage + require.Equal(t, 0, result.Usage.InputTokens) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType) + require.Equal(t, binaryPayload, clientWrites[0].payload) +} + +func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) { + t.Parallel() + + clientConn := 
newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageBinary, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, 0, result.Usage.InputTokens) + require.Equal(t, "", result.RequestID) + require.Equal(t, "", result.TerminalEventType) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType) +} + +func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + errorEvent := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: errorEvent, + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageText, clientWrites[0].msgType) + require.Equal(t, errorEvent, clientWrites[0].payload) +} + +func TestRelay_PreservesFirstMessageType(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, true) + + firstPayload := 
[]byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + FirstMessageType: coderws.MessageBinary, + }) + require.Nil(t, relayExit) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageBinary, upstreamWrites[0].msgType) + require.Equal(t, firstPayload, upstreamWrites[0].payload) +} + +func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) { + baseline := SnapshotMetrics().UsageParseFailureTotal + + // 上游发送无效 JSON(非 usage 格式),不应影响透传 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + // usage 解析失败,值为 0 但不影响透传 + require.Equal(t, 0, result.Usage.InputTokens) + require.Equal(t, "response.completed", result.TerminalEventType) + + // 帧仍然被转发 + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.GreaterOrEqual(t, SnapshotMetrics().UsageParseFailureTotal, baseline+1) +} + +func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) { + t.Parallel() + + // 上游连接立即关闭,首包写入失败 + upstreamConn := newPassthroughTestFrameConn(nil, true) + _ = upstreamConn.Close() + + // 覆盖 WriteFrame 使其返回错误 + errConn := &errorOnWriteFrameConn{} + clientConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := 
context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, errConn, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit) + require.Equal(t, "write_upstream", relayExit.Stage) +} + +func TestRelay_ContextCanceled(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + + // 立即取消 context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + // context 取消导致写首包失败 + require.NotNil(t, relayExit) +} + +func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + stages := make([]string, 0, 8) + var stagesMu sync.Mutex + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + OnTrace: func(event RelayTraceEvent) { + stagesMu.Lock() + stages = append(stages, event.Stage) + stagesMu.Unlock() + }, + }) + require.Nil(t, relayExit) + stagesMu.Lock() + capturedStages := append([]string(nil), stages...) 
+ stagesMu.Unlock() + require.Contains(t, capturedStages, "relay_start") + require.Contains(t, capturedStages, "write_first_message_ok") + require.Contains(t, capturedStages, "first_exit") + require.Contains(t, capturedStages, "relay_complete") +} + +func TestRelay_TraceEvents_IdleTimeout(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + now := time.Now() + callCount := 0 + nowFn := func() time.Time { + callCount++ + if callCount <= 5 { + return now + } + return now.Add(time.Hour) + } + + stages := make([]string, 0, 8) + var stagesMu sync.Mutex + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + IdleTimeout: 2 * time.Second, + Now: nowFn, + OnTrace: func(event RelayTraceEvent) { + stagesMu.Lock() + stages = append(stages, event.Stage) + stagesMu.Unlock() + }, + }) + require.NotNil(t, relayExit) + require.Equal(t, "idle_timeout", relayExit.Stage) + stagesMu.Lock() + capturedStages := append([]string(nil), stages...) 
+ stagesMu.Unlock() + require.Contains(t, capturedStages, "idle_timeout_triggered") + require.Contains(t, capturedStages, "relay_exit") +} + +// errorOnWriteFrameConn 是一个写入总是失败的 FrameConn 实现,用于测试首包写入失败。 +type errorOnWriteFrameConn struct{} + +func (c *errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + <-ctx.Done() + return coderws.MessageText, nil, ctx.Err() +} + +func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error { + return errors.New("write failed: connection refused") +} + +func (c *errorOnWriteFrameConn) Close() error { + return nil +} diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go new file mode 100644 index 00000000..cda2e351 --- /dev/null +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -0,0 +1,372 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "sync/atomic" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +type openAIWSClientFrameConn struct { + conn *coderws.Conn +} + +const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2" + +var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil) + +func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.conn == nil { + return coderws.MessageText, nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Read(ctx) +} + +func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if 
ctx == nil { + ctx = context.Background() + } + return c.conn.Write(ctx, msgType, payload) +} + +func (c *openAIWSClientFrameConn) Close() error { + if c == nil || c.conn == nil { + return nil + } + _ = c.conn.Close(coderws.StatusNormalClosure, "") + _ = c.conn.CloseNow() + return nil +} + +func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *Account, + token string, + firstClientMessage []byte, + hooks *OpenAIWSIngressHooks, + wsDecision OpenAIWSProtocolDecision, +) error { + if s == nil { + return errors.New("service is nil") + } + if clientConn == nil { + return errors.New("client websocket is nil") + } + if account == nil { + return errors.New("account is nil") + } + if strings.TrimSpace(token) == "" { + return errors.New("token is empty") + } + requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String()) + requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage) + requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String()) + logOpenAIWSV2Passthrough( + "relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d", + account.ID, + truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen), + openaiwsv2RelayMessageTypeName(coderws.MessageText), + len(firstClientMessage), + ) + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return fmt.Errorf("build ws url: %w", err) + } + wsHost := "-" + wsPath := "-" + if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil { + wsHost = normalizeOpenAIWSLogValue(parsedURL.Host) + wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) + } + logOpenAIWSV2Passthrough( + "relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + wsHost, + 
wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + + isCodexCLI := false + if c != nil { + isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) + } + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + isCodexCLI = true + } + headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "") + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + dialer := s.getOpenAIWSPassthroughDialer() + if dialer == nil { + return errors.New("openai ws passthrough dialer is nil") + } + + dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout()) + defer cancelDial() + upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL) + if err != nil { + logOpenAIWSV2Passthrough( + "relay_dial_failed account_id=%d status_code=%d err=%s", + account.ID, + statusCode, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders) + } + defer func() { + _ = upstreamConn.Close() + }() + logOpenAIWSV2Passthrough( + "relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s", + account.ID, + statusCode, + openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"), + ) + + upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn) + if !ok { + return errors.New("openai ws passthrough upstream connection does not support frame relay") + } + + completedTurns := atomic.Int32{} + relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{ + Ctx: ctx, + ClientConn: &openAIWSClientFrameConn{conn: clientConn}, + UpstreamConn: upstreamFrameConn, + FirstClientMessage: firstClientMessage, + Options: openaiwsv2.RelayOptions{ + WriteTimeout: s.openAIWSWriteTimeout(), + IdleTimeout: s.openAIWSPassthroughIdleTimeout(), + FirstMessageType: coderws.MessageText, + OnUsageParseFailure: func(eventType string, 
usageRaw string) { + logOpenAIWSV2Passthrough( + "usage_parse_failed event_type=%s usage_raw=%s", + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen), + ) + }, + OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) { + turnNo := int(completedTurns.Add(1)) + turnResult := &OpenAIForwardResult{ + RequestID: turn.RequestID, + Usage: OpenAIUsage{ + InputTokens: turn.Usage.InputTokens, + OutputTokens: turn.Usage.OutputTokens, + CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens, + CacheReadInputTokens: turn.Usage.CacheReadInputTokens, + }, + Model: turn.RequestModel, + ServiceTier: requestServiceTier, + Stream: true, + OpenAIWSMode: true, + ResponseHeaders: cloneHeader(handshakeHeaders), + Duration: turn.Duration, + FirstTokenMs: turn.FirstTokenMs, + } + logOpenAIWSV2Passthrough( + "relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d", + account.ID, + turnNo, + truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen), + turnResult.Duration.Milliseconds(), + openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs), + turnResult.Usage.InputTokens, + turnResult.Usage.OutputTokens, + turnResult.Usage.CacheReadInputTokens, + ) + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turnNo, turnResult, nil) + } + }, + OnTrace: func(event openaiwsv2.RelayTraceEvent) { + logOpenAIWSV2Passthrough( + "relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s", + account.ID, + truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen), + event.PayloadBytes, + event.Graceful, + event.WroteDownstream, + 
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen), + ) + }, + }, + }) + + result := &OpenAIForwardResult{ + RequestID: relayResult.RequestID, + Usage: OpenAIUsage{ + InputTokens: relayResult.Usage.InputTokens, + OutputTokens: relayResult.Usage.OutputTokens, + CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens, + CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, + }, + Model: relayResult.RequestModel, + ServiceTier: requestServiceTier, + Stream: true, + OpenAIWSMode: true, + ResponseHeaders: cloneHeader(handshakeHeaders), + Duration: relayResult.Duration, + FirstTokenMs: relayResult.FirstTokenMs, + } + + turnCount := int(completedTurns.Load()) + if relayExit == nil { + logOpenAIWSV2Passthrough( + "relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d", + account.ID, + truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen), + result.Duration.Milliseconds(), + relayResult.ClientToUpstreamFrames, + relayResult.UpstreamToClientFrames, + relayResult.DroppedDownstreamFrames, + turnCount, + ) + // 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。 + if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(1, result, nil) + } + return nil + } + logOpenAIWSV2Passthrough( + "relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d", + account.ID, + truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen), + relayExit.WroteDownstream, + truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen), + result.Duration.Milliseconds(), + relayResult.ClientToUpstreamFrames, + relayResult.UpstreamToClientFrames, + relayResult.DroppedDownstreamFrames, + turnCount, + ) + + relayErr := relayExit.Err + if relayExit.Stage == "idle_timeout" { + relayErr = 
NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "client websocket idle timeout", + relayErr, + ) + } + turnErr := wrapOpenAIWSIngressTurnError( + relayExit.Stage, + relayErr, + relayExit.WroteDownstream, + ) + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turnCount+1, nil, turnErr) + } + return turnErr +} + +func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError( + err error, + statusCode int, + handshakeHeaders http.Header, +) error { + if err == nil { + return nil + } + wrappedErr := err + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) { + wrappedErr = &openAIWSDialError{ + StatusCode: statusCode, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: err, + } + } + + if errors.Is(err, context.Canceled) { + return err + } + if errors.Is(err, context.DeadlineExceeded) { + return NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket connect timeout", + wrappedErr, + ) + } + if statusCode == http.StatusTooManyRequests { + return NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket is busy, please retry later", + wrappedErr, + ) + } + if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream websocket authentication failed", + wrappedErr, + ) + } + if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream websocket handshake rejected", + wrappedErr, + ) + } + return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr) +} + +func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string { + switch msgType { + case coderws.MessageText: + return "text" + case coderws.MessageBinary: + return "binary" + default: + return fmt.Sprintf("unknown(%d)", msgType) + } +} + +func relayErrorText(err error) string { + if err == nil { + return 
"" + } + return err.Error() +} + +func openAIWSFirstTokenMsForLog(firstTokenMs *int) int { + if firstTokenMs == nil { + return -1 + } + return *firstTokenMs +} + +func logOpenAIWSV2Passthrough(format string, args ...any) { + logger.LegacyPrintf( + "service.openai_ws_v2", + "[OpenAI WS v2 passthrough] %s "+format, + append([]any{openaiWSV2PassthroughModeFields}, args...)..., + ) +} diff --git a/backend/internal/service/ops_aggregation_service.go b/backend/internal/service/ops_aggregation_service.go index ec77fe12..89076ce2 100644 --- a/backend/internal/service/ops_aggregation_service.go +++ b/backend/internal/service/ops_aggregation_service.go @@ -23,7 +23,7 @@ const ( opsAggDailyInterval = 1 * time.Hour // Keep in sync with ops retention target (vNext default 30d). - opsAggBackfillWindow = 30 * 24 * time.Hour + opsAggBackfillWindow = 1 * time.Hour // Recompute overlap to absorb late-arriving rows near boundaries. opsAggHourlyOverlap = 2 * time.Hour @@ -36,7 +36,7 @@ const ( // that may still receive late inserts. 
opsAggSafeDelay = 5 * time.Minute - opsAggMaxQueryTimeout = 3 * time.Second + opsAggMaxQueryTimeout = 5 * time.Second opsAggHourlyTimeout = 5 * time.Minute opsAggDailyTimeout = 2 * time.Minute diff --git a/backend/internal/service/ops_alert_evaluator_service.go b/backend/internal/service/ops_alert_evaluator_service.go index 169a5e32..88883180 100644 --- a/backend/internal/service/ops_alert_evaluator_service.go +++ b/backend/internal/service/ops_alert_evaluator_service.go @@ -506,6 +506,48 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric( return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool { return acc.HasError && acc.TempUnschedulableUntil == nil })), true + case "group_rate_limit_ratio": + if groupID == nil || *groupID <= 0 { + return 0, false + } + if s == nil || s.opsService == nil { + return 0, false + } + availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID) + if err != nil || availability == nil { + return 0, false + } + if availability.Group == nil || availability.Group.TotalAccounts <= 0 { + return 0, true + } + return (float64(availability.Group.RateLimitCount) / float64(availability.Group.TotalAccounts)) * 100, true + case "account_error_ratio": + if s == nil || s.opsService == nil { + return 0, false + } + availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID) + if err != nil || availability == nil { + return 0, false + } + total := int64(len(availability.Accounts)) + if total <= 0 { + return 0, true + } + errorCount := countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool { + return acc.HasError && acc.TempUnschedulableUntil == nil + }) + return (float64(errorCount) / float64(total)) * 100, true + case "overload_account_count": + if s == nil || s.opsService == nil { + return 0, false + } + availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID) + if err != nil || availability == nil { + return 
0, false + } + return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool { + return acc.IsOverloaded + })), true } overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{ diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index 92b37e73..a571dd4d 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -64,8 +64,12 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts if acc.ID <= 0 { continue } - if prev, ok := unique[acc.ID]; !ok || acc.Concurrency > prev { - unique[acc.ID] = acc.Concurrency + c := acc.Concurrency + if c <= 0 { + c = 1 + } + if prev, ok := unique[acc.ID]; !ok || c > prev { + unique[acc.ID] = c } } diff --git a/backend/internal/service/ops_dashboard.go b/backend/internal/service/ops_dashboard.go index 31822ba8..6f70c75c 100644 --- a/backend/internal/service/ops_dashboard.go +++ b/backend/internal/service/ops_dashboard.go @@ -31,6 +31,10 @@ func (s *OpsService) GetDashboardOverview(ctx context.Context, filter *OpsDashbo filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) overview, err := s.opsRepo.GetDashboardOverview(ctx, filter) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + overview, err = s.opsRepo.GetDashboardOverview(ctx, rawFilter) + } if err != nil { if errors.Is(err, ErrOpsPreaggregatedNotPopulated) { return nil, infraerrors.Conflict("OPS_PREAGG_NOT_READY", "Pre-aggregated ops metrics are not populated yet") diff --git a/backend/internal/service/ops_errors.go b/backend/internal/service/ops_errors.go index 76b5ce8b..01671c1e 100644 --- a/backend/internal/service/ops_errors.go +++ b/backend/internal/service/ops_errors.go @@ -22,7 +22,14 @@ func (s *OpsService) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilt if filter.StartTime.After(filter.EndTime) { 
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds) + filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return s.opsRepo.GetErrorTrend(ctx, rawFilter, bucketSeconds) + } + return result, err } func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) { @@ -41,5 +48,12 @@ func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashbo if filter.StartTime.After(filter.EndTime) { return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetErrorDistribution(ctx, filter) + filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetErrorDistribution(ctx, filter) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return s.opsRepo.GetErrorDistribution(ctx, rawFilter) + } + return result, err } diff --git a/backend/internal/service/ops_histograms.go b/backend/internal/service/ops_histograms.go index 9f5b514f..c555dbfc 100644 --- a/backend/internal/service/ops_histograms.go +++ b/backend/internal/service/ops_histograms.go @@ -22,5 +22,12 @@ func (s *OpsService) GetLatencyHistogram(ctx context.Context, filter *OpsDashboa if filter.StartTime.After(filter.EndTime) { return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetLatencyHistogram(ctx, filter) + filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetLatencyHistogram(ctx, filter) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := 
cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return s.opsRepo.GetLatencyHistogram(ctx, rawFilter) + } + return result, err } diff --git a/backend/internal/service/ops_metrics_collector.go b/backend/internal/service/ops_metrics_collector.go index 30adaae0..f93481e7 100644 --- a/backend/internal/service/ops_metrics_collector.go +++ b/backend/internal/service/ops_metrics_collector.go @@ -389,13 +389,9 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con if acc.ID <= 0 { continue } - maxConc := acc.Concurrency - if maxConc < 0 { - maxConc = 0 - } batch = append(batch, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: maxConc, + MaxConcurrency: acc.Concurrency, }) } if len(batch) == 0 { diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index f3633eae..0ce9d425 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -7,6 +7,7 @@ import ( type OpsRepository interface { InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) + BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error) diff --git a/backend/internal/service/ops_query_mode.go b/backend/internal/service/ops_query_mode.go index e6fa9c1e..fa97f358 100644 --- a/backend/internal/service/ops_query_mode.go +++ b/backend/internal/service/ops_query_mode.go @@ -38,3 +38,18 @@ func (m OpsQueryMode) IsValid() bool { return false } } + +func shouldFallbackOpsPreagg(filter *OpsDashboardFilter, err error) bool { + return filter != nil && + filter.QueryMode == OpsQueryModeAuto && + errors.Is(err, ErrOpsPreaggregatedNotPopulated) +} + +func cloneOpsFilterWithMode(filter 
*OpsDashboardFilter, mode OpsQueryMode) *OpsDashboardFilter { + if filter == nil { + return nil + } + cloned := *filter + cloned.QueryMode = mode + return &cloned +} diff --git a/backend/internal/service/ops_query_mode_test.go b/backend/internal/service/ops_query_mode_test.go new file mode 100644 index 00000000..26c4b730 --- /dev/null +++ b/backend/internal/service/ops_query_mode_test.go @@ -0,0 +1,66 @@ +//go:build unit + +package service + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestShouldFallbackOpsPreagg(t *testing.T) { + preaggErr := ErrOpsPreaggregatedNotPopulated + otherErr := errors.New("some other error") + + autoFilter := &OpsDashboardFilter{QueryMode: OpsQueryModeAuto} + rawFilter := &OpsDashboardFilter{QueryMode: OpsQueryModeRaw} + preaggFilter := &OpsDashboardFilter{QueryMode: OpsQueryModePreagg} + + tests := []struct { + name string + filter *OpsDashboardFilter + err error + want bool + }{ + {"auto mode + preagg error => fallback", autoFilter, preaggErr, true}, + {"auto mode + other error => no fallback", autoFilter, otherErr, false}, + {"auto mode + nil error => no fallback", autoFilter, nil, false}, + {"raw mode + preagg error => no fallback", rawFilter, preaggErr, false}, + {"preagg mode + preagg error => no fallback", preaggFilter, preaggErr, false}, + {"nil filter => no fallback", nil, preaggErr, false}, + {"wrapped preagg error => fallback", autoFilter, errors.Join(preaggErr, otherErr), true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := shouldFallbackOpsPreagg(tc.filter, tc.err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestCloneOpsFilterWithMode(t *testing.T) { + t.Run("nil filter returns nil", func(t *testing.T) { + require.Nil(t, cloneOpsFilterWithMode(nil, OpsQueryModeRaw)) + }) + + t.Run("cloned filter has new mode", func(t *testing.T) { + groupID := int64(42) + original := &OpsDashboardFilter{ + StartTime: time.Now(), + EndTime: 
time.Now().Add(time.Hour), + Platform: "anthropic", + GroupID: &groupID, + QueryMode: OpsQueryModeAuto, + } + + cloned := cloneOpsFilterWithMode(original, OpsQueryModeRaw) + require.Equal(t, OpsQueryModeRaw, cloned.QueryMode) + require.Equal(t, OpsQueryModeAuto, original.QueryMode, "original should not be modified") + require.Equal(t, original.Platform, cloned.Platform) + require.Equal(t, original.StartTime, cloned.StartTime) + require.Equal(t, original.GroupID, cloned.GroupID) + }) +} diff --git a/backend/internal/service/ops_repo_mock_test.go b/backend/internal/service/ops_repo_mock_test.go index e250dea3..c8c66ec6 100644 --- a/backend/internal/service/ops_repo_mock_test.go +++ b/backend/internal/service/ops_repo_mock_test.go @@ -7,6 +7,8 @@ import ( // opsRepoMock is a test-only OpsRepository implementation with optional function hooks. type opsRepoMock struct { + InsertErrorLogFn func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) + BatchInsertErrorLogsFn func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) @@ -14,9 +16,19 @@ type opsRepoMock struct { } func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) { + if m.InsertErrorLogFn != nil { + return m.InsertErrorLogFn(ctx, input) + } return 0, nil } +func (m *opsRepoMock) BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) { + if m.BatchInsertErrorLogsFn != nil { + return m.BatchInsertErrorLogsFn(ctx, inputs) + } + return int64(len(inputs)), nil +} + func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) { return &OpsErrorLogList{Errors: 
[]*OpsErrorLog{}, Page: 1, PageSize: 20}, nil } diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index f0daa3e2..fdabbafd 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -467,7 +467,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: selErr.Error()} } if selection == nil || selection.Account == nil { - return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "no available accounts"} + return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: ErrNoAvailableAccounts.Error()} } account := selection.Account diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go index 767d1704..29f0aa8b 100644 --- a/backend/internal/service/ops_service.go +++ b/backend/internal/service/ops_service.go @@ -121,14 +121,74 @@ func (s *OpsService) IsMonitoringEnabled(ctx context.Context) bool { } func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error { - if entry == nil { + prepared, ok, err := s.prepareErrorLogInput(ctx, entry, rawRequestBody) + if err != nil { + log.Printf("[Ops] RecordError prepare failed: %v", err) + return err + } + if !ok { return nil } + + if _, err := s.opsRepo.InsertErrorLog(ctx, prepared); err != nil { + // Never bubble up to gateway; best-effort logging. 
+ log.Printf("[Ops] RecordError failed: %v", err) + return err + } + return nil +} + +func (s *OpsService) RecordErrorBatch(ctx context.Context, entries []*OpsInsertErrorLogInput) error { + if len(entries) == 0 { + return nil + } + prepared := make([]*OpsInsertErrorLogInput, 0, len(entries)) + for _, entry := range entries { + item, ok, err := s.prepareErrorLogInput(ctx, entry, nil) + if err != nil { + log.Printf("[Ops] RecordErrorBatch prepare failed: %v", err) + continue + } + if ok { + prepared = append(prepared, item) + } + } + if len(prepared) == 0 { + return nil + } + if len(prepared) == 1 { + _, err := s.opsRepo.InsertErrorLog(ctx, prepared[0]) + if err != nil { + log.Printf("[Ops] RecordErrorBatch single insert failed: %v", err) + } + return err + } + + if _, err := s.opsRepo.BatchInsertErrorLogs(ctx, prepared); err != nil { + log.Printf("[Ops] RecordErrorBatch failed, fallback to single inserts: %v", err) + var firstErr error + for _, entry := range prepared { + if _, insertErr := s.opsRepo.InsertErrorLog(ctx, entry); insertErr != nil { + log.Printf("[Ops] RecordErrorBatch fallback insert failed: %v", insertErr) + if firstErr == nil { + firstErr = insertErr + } + } + } + return firstErr + } + return nil +} + +func (s *OpsService) prepareErrorLogInput(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) (*OpsInsertErrorLogInput, bool, error) { + if entry == nil { + return nil, false, nil + } if !s.IsMonitoringEnabled(ctx) { - return nil + return nil, false, nil } if s.opsRepo == nil { - return nil + return nil, false, nil } // Ensure timestamps are always populated. @@ -185,85 +245,88 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn } } - // Sanitize + serialize upstream error events list. 
- if len(entry.UpstreamErrors) > 0 { - const maxEvents = 32 - events := entry.UpstreamErrors - if len(events) > maxEvents { - events = events[len(events)-maxEvents:] + if err := sanitizeOpsUpstreamErrors(entry); err != nil { + return nil, false, err + } + + return entry, true, nil +} + +func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error { + if entry == nil || len(entry.UpstreamErrors) == 0 { + return nil + } + + const maxEvents = 32 + events := entry.UpstreamErrors + if len(events) > maxEvents { + events = events[len(events)-maxEvents:] + } + + sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events)) + for _, ev := range events { + if ev == nil { + continue + } + out := *ev + + out.Platform = strings.TrimSpace(out.Platform) + out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128) + out.Kind = truncateString(strings.TrimSpace(out.Kind), 64) + + if out.AccountID < 0 { + out.AccountID = 0 + } + if out.UpstreamStatusCode < 0 { + out.UpstreamStatusCode = 0 + } + if out.AtUnixMs < 0 { + out.AtUnixMs = 0 } - sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events)) - for _, ev := range events { - if ev == nil { - continue - } - out := *ev + msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message)) + msg = truncateString(msg, 2048) + out.Message = msg - out.Platform = strings.TrimSpace(out.Platform) - out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128) - out.Kind = truncateString(strings.TrimSpace(out.Kind), 64) + detail := strings.TrimSpace(out.Detail) + if detail != "" { + // Keep upstream detail small; request bodies are not stored here, only upstream error payloads. 
+ sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes) + out.Detail = sanitizedDetail + } else { + out.Detail = "" + } - if out.AccountID < 0 { - out.AccountID = 0 - } - if out.UpstreamStatusCode < 0 { - out.UpstreamStatusCode = 0 - } - if out.AtUnixMs < 0 { - out.AtUnixMs = 0 - } - - msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message)) - msg = truncateString(msg, 2048) - out.Message = msg - - detail := strings.TrimSpace(out.Detail) - if detail != "" { - // Keep upstream detail small; request bodies are not stored here, only upstream error payloads. - sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes) - out.Detail = sanitizedDetail - } else { - out.Detail = "" - } - - out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody) - if out.UpstreamRequestBody != "" { - // Reuse the same sanitization/trimming strategy as request body storage. - // Keep it small so it is safe to persist in ops_error_logs JSON. - sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024) - if sanitized != "" { - out.UpstreamRequestBody = sanitized - if truncated { - out.Kind = strings.TrimSpace(out.Kind) - if out.Kind == "" { - out.Kind = "upstream" - } - out.Kind = out.Kind + ":request_body_truncated" + out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody) + if out.UpstreamRequestBody != "" { + // Reuse the same sanitization/trimming strategy as request body storage. + // Keep it small so it is safe to persist in ops_error_logs JSON. 
+ sanitizedBody, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024) + if sanitizedBody != "" { + out.UpstreamRequestBody = sanitizedBody + if truncated { + out.Kind = strings.TrimSpace(out.Kind) + if out.Kind == "" { + out.Kind = "upstream" } - } else { - out.UpstreamRequestBody = "" + out.Kind = out.Kind + ":request_body_truncated" } + } else { + out.UpstreamRequestBody = "" } - - // Drop fully-empty events (can happen if only status code was known). - if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" { - continue - } - - evCopy := out - sanitized = append(sanitized, &evCopy) } - entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized) - entry.UpstreamErrors = nil + // Drop fully-empty events (can happen if only status code was known). + if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" { + continue + } + + evCopy := out + sanitized = append(sanitized, &evCopy) } - if _, err := s.opsRepo.InsertErrorLog(ctx, entry); err != nil { - // Never bubble up to gateway; best-effort logging. - log.Printf("[Ops] RecordError failed: %v", err) - return err - } + entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized) + entry.UpstreamErrors = nil return nil } diff --git a/backend/internal/service/ops_service_batch_test.go b/backend/internal/service/ops_service_batch_test.go new file mode 100644 index 00000000..f3a14d7f --- /dev/null +++ b/backend/internal/service/ops_service_batch_test.go @@ -0,0 +1,103 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestOpsServiceRecordErrorBatch_SanitizesAndBatches(t *testing.T) { + t.Parallel() + + var captured []*OpsInsertErrorLogInput + repo := &opsRepoMock{ + BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) { + captured = append(captured, inputs...) 
+ return int64(len(inputs)), nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + + msg := " upstream failed: https://example.com?access_token=secret-value " + detail := `{"authorization":"Bearer secret-token"}` + entries := []*OpsInsertErrorLogInput{ + { + ErrorBody: `{"error":"bad","access_token":"secret"}`, + UpstreamStatusCode: intPtr(-10), + UpstreamErrorMessage: strPtr(msg), + UpstreamErrorDetail: strPtr(detail), + UpstreamErrors: []*OpsUpstreamErrorEvent{ + { + AccountID: -2, + UpstreamStatusCode: 429, + Message: " token leaked ", + Detail: `{"refresh_token":"secret"}`, + UpstreamRequestBody: `{"api_key":"secret","messages":[{"role":"user","content":"hello"}]}`, + }, + }, + }, + { + ErrorPhase: "upstream", + ErrorType: "upstream_error", + CreatedAt: time.Now().UTC(), + }, + } + + require.NoError(t, svc.RecordErrorBatch(context.Background(), entries)) + require.Len(t, captured, 2) + + first := captured[0] + require.Equal(t, "internal", first.ErrorPhase) + require.Equal(t, "api_error", first.ErrorType) + require.Nil(t, first.UpstreamStatusCode) + require.NotNil(t, first.UpstreamErrorMessage) + require.NotContains(t, *first.UpstreamErrorMessage, "secret-value") + require.Contains(t, *first.UpstreamErrorMessage, "access_token=***") + require.NotNil(t, first.UpstreamErrorDetail) + require.NotContains(t, *first.UpstreamErrorDetail, "secret-token") + require.NotContains(t, first.ErrorBody, "secret") + require.Nil(t, first.UpstreamErrors) + require.NotNil(t, first.UpstreamErrorsJSON) + require.NotContains(t, *first.UpstreamErrorsJSON, "secret") + require.Contains(t, *first.UpstreamErrorsJSON, "[REDACTED]") + + second := captured[1] + require.Equal(t, "upstream", second.ErrorPhase) + require.Equal(t, "upstream_error", second.ErrorType) + require.False(t, second.CreatedAt.IsZero()) +} + +func TestOpsServiceRecordErrorBatch_FallsBackToSingleInsert(t *testing.T) { + t.Parallel() + + var ( + batchCalls int + singleCalls int + ) + 
repo := &opsRepoMock{ + BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) { + batchCalls++ + return 0, errors.New("batch failed") + }, + InsertErrorLogFn: func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) { + singleCalls++ + return int64(singleCalls), nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + + err := svc.RecordErrorBatch(context.Background(), []*OpsInsertErrorLogInput{ + {ErrorMessage: "first"}, + {ErrorMessage: "second"}, + }) + require.NoError(t, err) + require.Equal(t, 1, batchCalls) + require.Equal(t, 2, singleCalls) +} + +func strPtr(v string) *string { + return &v +} diff --git a/backend/internal/service/ops_settings.go b/backend/internal/service/ops_settings.go index a6a4a0d7..5871166c 100644 --- a/backend/internal/service/ops_settings.go +++ b/backend/internal/service/ops_settings.go @@ -368,11 +368,14 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings { Aggregation: OpsAggregationSettings{ AggregationEnabled: false, }, - IgnoreCountTokensErrors: false, - IgnoreContextCanceled: true, // Default to true - client disconnects are not errors - IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue - AutoRefreshEnabled: false, - AutoRefreshIntervalSec: 30, + IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略 + IgnoreContextCanceled: true, // Default to true - client disconnects are not errors + IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue + IgnoreInsufficientBalanceErrors: false, // 默认不忽略,余额不足可能需要关注 + DisplayOpenAITokenStats: false, + DisplayAlertEvents: true, + AutoRefreshEnabled: false, + AutoRefreshIntervalSec: 30, } } @@ -438,7 +441,7 @@ func (s *OpsService) GetOpsAdvancedSettings(ctx context.Context) (*OpsAdvancedSe return nil, err } - cfg := &OpsAdvancedSettings{} + cfg := defaultOpsAdvancedSettings() if err := json.Unmarshal([]byte(raw), cfg); err 
!= nil { return defaultCfg, nil } diff --git a/backend/internal/service/ops_settings_advanced_test.go b/backend/internal/service/ops_settings_advanced_test.go new file mode 100644 index 00000000..06cc545b --- /dev/null +++ b/backend/internal/service/ops_settings_advanced_test.go @@ -0,0 +1,97 @@ +package service + +import ( + "context" + "encoding/json" + "testing" +) + +func TestGetOpsAdvancedSettings_DefaultHidesOpenAITokenStats(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{settingRepo: repo} + + cfg, err := svc.GetOpsAdvancedSettings(context.Background()) + if err != nil { + t.Fatalf("GetOpsAdvancedSettings() error = %v", err) + } + if cfg.DisplayOpenAITokenStats { + t.Fatalf("DisplayOpenAITokenStats = true, want false by default") + } + if !cfg.DisplayAlertEvents { + t.Fatalf("DisplayAlertEvents = false, want true by default") + } + if repo.setCalls != 1 { + t.Fatalf("expected defaults to be persisted once, got %d", repo.setCalls) + } +} + +func TestUpdateOpsAdvancedSettings_PersistsOpenAITokenStatsVisibility(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{settingRepo: repo} + + cfg := defaultOpsAdvancedSettings() + cfg.DisplayOpenAITokenStats = true + cfg.DisplayAlertEvents = false + + updated, err := svc.UpdateOpsAdvancedSettings(context.Background(), cfg) + if err != nil { + t.Fatalf("UpdateOpsAdvancedSettings() error = %v", err) + } + if !updated.DisplayOpenAITokenStats { + t.Fatalf("DisplayOpenAITokenStats = false, want true") + } + if updated.DisplayAlertEvents { + t.Fatalf("DisplayAlertEvents = true, want false") + } + + reloaded, err := svc.GetOpsAdvancedSettings(context.Background()) + if err != nil { + t.Fatalf("GetOpsAdvancedSettings() after update error = %v", err) + } + if !reloaded.DisplayOpenAITokenStats { + t.Fatalf("reloaded DisplayOpenAITokenStats = false, want true") + } + if reloaded.DisplayAlertEvents { + t.Fatalf("reloaded DisplayAlertEvents = true, want false") + } +} + +func 
TestGetOpsAdvancedSettings_BackfillsNewDisplayFlagsFromDefaults(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{settingRepo: repo} + + legacyCfg := map[string]any{ + "data_retention": map[string]any{ + "cleanup_enabled": false, + "cleanup_schedule": "0 2 * * *", + "error_log_retention_days": 30, + "minute_metrics_retention_days": 30, + "hourly_metrics_retention_days": 30, + }, + "aggregation": map[string]any{ + "aggregation_enabled": false, + }, + "ignore_count_tokens_errors": true, + "ignore_context_canceled": true, + "ignore_no_available_accounts": false, + "ignore_invalid_api_key_errors": false, + "auto_refresh_enabled": false, + "auto_refresh_interval_seconds": 30, + } + raw, err := json.Marshal(legacyCfg) + if err != nil { + t.Fatalf("marshal legacy config: %v", err) + } + repo.values[SettingKeyOpsAdvancedSettings] = string(raw) + + cfg, err := svc.GetOpsAdvancedSettings(context.Background()) + if err != nil { + t.Fatalf("GetOpsAdvancedSettings() error = %v", err) + } + if cfg.DisplayOpenAITokenStats { + t.Fatalf("DisplayOpenAITokenStats = true, want false default backfill") + } + if !cfg.DisplayAlertEvents { + t.Fatalf("DisplayAlertEvents = false, want true default backfill") + } +} diff --git a/backend/internal/service/ops_settings_models.go b/backend/internal/service/ops_settings_models.go index 8b5359e3..fa18b05f 100644 --- a/backend/internal/service/ops_settings_models.go +++ b/backend/internal/service/ops_settings_models.go @@ -92,14 +92,17 @@ type OpsAlertRuntimeSettings struct { // OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation). 
type OpsAdvancedSettings struct { - DataRetention OpsDataRetentionSettings `json:"data_retention"` - Aggregation OpsAggregationSettings `json:"aggregation"` - IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"` - IgnoreContextCanceled bool `json:"ignore_context_canceled"` - IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"` - IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"` - AutoRefreshEnabled bool `json:"auto_refresh_enabled"` - AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"` + DataRetention OpsDataRetentionSettings `json:"data_retention"` + Aggregation OpsAggregationSettings `json:"aggregation"` + IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"` + IgnoreContextCanceled bool `json:"ignore_context_canceled"` + IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"` + IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"` + IgnoreInsufficientBalanceErrors bool `json:"ignore_insufficient_balance_errors"` + DisplayOpenAITokenStats bool `json:"display_openai_token_stats"` + DisplayAlertEvents bool `json:"display_alert_events"` + AutoRefreshEnabled bool `json:"auto_refresh_enabled"` + AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"` } type OpsDataRetentionSettings struct { diff --git a/backend/internal/service/ops_trends.go b/backend/internal/service/ops_trends.go index ec55c6ce..22db72ef 100644 --- a/backend/internal/service/ops_trends.go +++ b/backend/internal/service/ops_trends.go @@ -22,5 +22,13 @@ func (s *OpsService) GetThroughputTrend(ctx context.Context, filter *OpsDashboar if filter.StartTime.After(filter.EndTime) { return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds) + + filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds) + if err != 
nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return s.opsRepo.GetThroughputTrend(ctx, rawFilter, bucketSeconds) + } + return result, err } diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 41e8b5eb..7ed4e7e4 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -21,18 +21,36 @@ import ( ) var ( - openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`) - openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`) + openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`) + openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`) + openAIGPT54FallbackPricing = &LiteLLMModelPricing{ + InputCostPerToken: 2.5e-06, // $2.5 per MTok + OutputCostPerToken: 1.5e-05, // $15 per MTok + CacheReadInputTokenCost: 2.5e-07, // $0.25 per MTok + LongContextInputTokenThreshold: 272000, + LongContextInputCostMultiplier: 2.0, + LongContextOutputCostMultiplier: 1.5, + LiteLLMProvider: "openai", + Mode: "chat", + SupportsPromptCaching: true, + } ) // LiteLLMModelPricing LiteLLM价格数据结构 // 只保留我们需要的字段,使用指针来处理可能缺失的值 type LiteLLMModelPricing struct { InputCostPerToken float64 `json:"input_cost_per_token"` + InputCostPerTokenPriority float64 `json:"input_cost_per_token_priority"` OutputCostPerToken float64 `json:"output_cost_per_token"` + OutputCostPerTokenPriority float64 `json:"output_cost_per_token_priority"` CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"` CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` + CacheReadInputTokenCostPriority float64 `json:"cache_read_input_token_cost_priority"` + LongContextInputTokenThreshold int `json:"long_context_input_token_threshold,omitempty"` + LongContextInputCostMultiplier float64 
`json:"long_context_input_cost_multiplier,omitempty"` + LongContextOutputCostMultiplier float64 `json:"long_context_output_cost_multiplier,omitempty"` + SupportsServiceTier bool `json:"supports_service_tier"` LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` @@ -48,10 +66,14 @@ type PricingRemoteClient interface { // LiteLLMRawEntry 用于解析原始JSON数据 type LiteLLMRawEntry struct { InputCostPerToken *float64 `json:"input_cost_per_token"` + InputCostPerTokenPriority *float64 `json:"input_cost_per_token_priority"` OutputCostPerToken *float64 `json:"output_cost_per_token"` + OutputCostPerTokenPriority *float64 `json:"output_cost_per_token_priority"` CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"` CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` + CacheReadInputTokenCostPriority *float64 `json:"cache_read_input_token_cost_priority"` + SupportsServiceTier bool `json:"supports_service_tier"` LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` @@ -310,14 +332,21 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel LiteLLMProvider: entry.LiteLLMProvider, Mode: entry.Mode, SupportsPromptCaching: entry.SupportsPromptCaching, + SupportsServiceTier: entry.SupportsServiceTier, } if entry.InputCostPerToken != nil { pricing.InputCostPerToken = *entry.InputCostPerToken } + if entry.InputCostPerTokenPriority != nil { + pricing.InputCostPerTokenPriority = *entry.InputCostPerTokenPriority + } if entry.OutputCostPerToken != nil { pricing.OutputCostPerToken = *entry.OutputCostPerToken } + if entry.OutputCostPerTokenPriority != nil { + pricing.OutputCostPerTokenPriority = *entry.OutputCostPerTokenPriority + } if entry.CacheCreationInputTokenCost != nil { 
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost } @@ -327,6 +356,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel if entry.CacheReadInputTokenCost != nil { pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost } + if entry.CacheReadInputTokenCostPriority != nil { + pricing.CacheReadInputTokenCostPriority = *entry.CacheReadInputTokenCostPriority + } if entry.OutputCostPerImage != nil { pricing.OutputCostPerImage = *entry.OutputCostPerImage } @@ -660,7 +692,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { // 2. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等) // 3. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) // 4. gpt-5.3-codex -> gpt-5.2-codex -// 5. 最终回退到 DefaultTestModel (gpt-5.1-codex) +// 5. gpt-5.4* -> 业务静态兜底价 +// 6. 最终回退到 DefaultTestModel (gpt-5.1-codex) func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { if strings.HasPrefix(model, "gpt-5.3-codex-spark") { if pricing, ok := s.pricingData["gpt-5.1-codex"]; ok { @@ -690,6 +723,12 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { } } + if strings.HasPrefix(model, "gpt-5.4") { + logger.With(zap.String("component", "service.pricing")). 
+ Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)")) + return openAIGPT54FallbackPricing + } + // 最终回退到 DefaultTestModel defaultModel := strings.ToLower(openai.DefaultTestModel) if pricing, ok := s.pricingData[defaultModel]; ok { diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index 127ff342..775024fd 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -1,11 +1,40 @@ package service import ( + "encoding/json" "testing" "github.com/stretchr/testify/require" ) +func TestParsePricingData_ParsesPriorityAndServiceTierFields(t *testing.T) { + svc := &PricingService{} + body := []byte(`{ + "gpt-5.4": { + "input_cost_per_token": 0.0000025, + "input_cost_per_token_priority": 0.000005, + "output_cost_per_token": 0.000015, + "output_cost_per_token_priority": 0.00003, + "cache_creation_input_token_cost": 0.0000025, + "cache_read_input_token_cost": 0.00000025, + "cache_read_input_token_cost_priority": 0.0000005, + "supports_service_tier": true, + "supports_prompt_caching": true, + "litellm_provider": "openai", + "mode": "chat" + } + }`) + + data, err := svc.parsePricingData(body) + require.NoError(t, err) + pricing := data["gpt-5.4"] + require.NotNil(t, pricing) + require.InDelta(t, 5e-6, pricing.InputCostPerTokenPriority, 1e-12) + require.InDelta(t, 3e-5, pricing.OutputCostPerTokenPriority, 1e-12) + require.InDelta(t, 5e-7, pricing.CacheReadInputTokenCostPriority, 1e-12) + require.True(t, pricing.SupportsServiceTier) +} + func TestGetModelPricing_Gpt53CodexSparkUsesGpt51CodexPricing(t *testing.T) { sparkPricing := &LiteLLMModelPricing{InputCostPerToken: 1} gpt53Pricing := &LiteLLMModelPricing{InputCostPerToken: 9} @@ -51,3 +80,81 @@ func TestGetModelPricing_OpenAIFallbackMatchedLoggedAsInfo(t *testing.T) { require.True(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> 
gpt-5.2-codex", "info")) require.False(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "warn")) } + +func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": &LiteLLMModelPricing{InputCostPerToken: 1.25e-6}, + }, + } + + got := svc.GetModelPricing("gpt-5.4") + require.NotNil(t, got) + require.InDelta(t, 2.5e-6, got.InputCostPerToken, 1e-12) + require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12) + require.InDelta(t, 2.5e-7, got.CacheReadInputTokenCost, 1e-12) + require.Equal(t, 272000, got.LongContextInputTokenThreshold) + require.InDelta(t, 2.0, got.LongContextInputCostMultiplier, 1e-12) + require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12) +} + +func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) { + raw := map[string]any{ + "gpt-5.4": map[string]any{ + "input_cost_per_token": 2.5e-6, + "input_cost_per_token_priority": 5e-6, + "output_cost_per_token": 15e-6, + "output_cost_per_token_priority": 30e-6, + "cache_read_input_token_cost": 0.25e-6, + "cache_read_input_token_cost_priority": 0.5e-6, + "supports_service_tier": true, + "supports_prompt_caching": true, + "litellm_provider": "openai", + "mode": "chat", + }, + } + body, err := json.Marshal(raw) + require.NoError(t, err) + + svc := &PricingService{} + pricingMap, err := svc.parsePricingData(body) + require.NoError(t, err) + + pricing := pricingMap["gpt-5.4"] + require.NotNil(t, pricing) + require.InDelta(t, 2.5e-6, pricing.InputCostPerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.InputCostPerTokenPriority, 1e-12) + require.InDelta(t, 15e-6, pricing.OutputCostPerToken, 1e-12) + require.InDelta(t, 30e-6, pricing.OutputCostPerTokenPriority, 1e-12) + require.InDelta(t, 0.25e-6, pricing.CacheReadInputTokenCost, 1e-12) + require.InDelta(t, 0.5e-6, pricing.CacheReadInputTokenCostPriority, 1e-12) + 
require.True(t, pricing.SupportsServiceTier) +} + +func TestParsePricingData_PreservesServiceTierPriorityFields(t *testing.T) { + svc := &PricingService{} + pricingData, err := svc.parsePricingData([]byte(`{ + "gpt-5.4": { + "input_cost_per_token": 0.0000025, + "input_cost_per_token_priority": 0.000005, + "output_cost_per_token": 0.000015, + "output_cost_per_token_priority": 0.00003, + "cache_read_input_token_cost": 0.00000025, + "cache_read_input_token_cost_priority": 0.0000005, + "supports_service_tier": true, + "litellm_provider": "openai", + "mode": "chat" + } + }`)) + require.NoError(t, err) + + pricing := pricingData["gpt-5.4"] + require.NotNil(t, pricing) + require.InDelta(t, 0.0000025, pricing.InputCostPerToken, 1e-12) + require.InDelta(t, 0.000005, pricing.InputCostPerTokenPriority, 1e-12) + require.InDelta(t, 0.000015, pricing.OutputCostPerToken, 1e-12) + require.InDelta(t, 0.00003, pricing.OutputCostPerTokenPriority, 1e-12) + require.InDelta(t, 0.00000025, pricing.CacheReadInputTokenCost, 1e-12) + require.InDelta(t, 0.0000005, pricing.CacheReadInputTokenCostPriority, 1e-12) + require.True(t, pricing.SupportsServiceTier) +} diff --git a/backend/internal/service/prompts/codex_cli_instructions.md b/backend/internal/service/prompts/codex_cli_instructions.md deleted file mode 100644 index 4886c7ef..00000000 --- a/backend/internal/service/prompts/codex_cli_instructions.md +++ /dev/null @@ -1,275 +0,0 @@ -You are a coding agent running in the Codex CLI, a terminal-based coding assistant. Codex CLI is an open source project led by OpenAI. You are expected to be precise, safe, and helpful. - -Your capabilities: - -- Receive user prompts and other context provided by the harness, such as files in the workspace. -- Communicate with the user by streaming thinking & responses, and by making & updating plans. -- Emit function calls to run terminal commands and apply patches. 
Depending on how this specific run is configured, you can request that these function calls be escalated to the user for approval before running. More on this in the "Sandbox and approvals" section. - -Within this context, Codex refers to the open-source agentic coding interface (not the old Codex language model built by OpenAI). - -# How you work - -## Personality - -Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work. - -# AGENTS.md spec -- Repos often contain AGENTS.md files. These files can appear anywhere within the repository. -- These files are a way for humans to give you (the agent) instructions or tips for working within the container. -- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code. -- Instructions in AGENTS.md files: - - The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it. - - For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file. - - Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise. - - More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions. - - Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions. -- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. 
When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable. - -## Responsiveness - -### Preamble messages - -Before making tool calls, send a brief preamble to the user explaining what you’re about to do. When sending preamble messages, follow these principles and examples: - -- **Logically group related actions**: if you’re about to run several related commands, describe them together in one preamble rather than sending a separate note for each. -- **Keep it concise**: be no more than 1-2 sentences, focused on immediate, tangible next steps. (8–12 words for quick updates). -- **Build on prior context**: if this is not your first tool call, use the preamble message to connect the dots with what’s been done so far and create a sense of momentum and clarity for the user to understand your next actions. -- **Keep your tone light, friendly and curious**: add small touches of personality in preambles feel collaborative and engaging. -- **Exception**: Avoid adding a preamble for every trivial read (e.g., `cat` a single file) unless it’s part of a larger grouped action. - -**Examples:** - -- “I’ve explored the repo; now checking the API route definitions.” -- “Next, I’ll patch the config and update the related tests.” -- “I’m about to scaffold the CLI commands and helper functions.” -- “Ok cool, so I’ve wrapped my head around the repo. Now digging into the API routes.” -- “Config’s looking tidy. Next up is patching helpers to keep things in sync.” -- “Finished poking at the DB gateway. I will now chase down error handling.” -- “Alright, build pipeline order is interesting. Checking how it reports failures.” -- “Spotted a clever caching util; now hunting where it gets used.” - -## Planning - -You have access to an `update_plan` tool which tracks steps and progress and renders them to the user. Using the tool helps demonstrate that you've understood the task and convey how you're approaching it. 
Plans can help to make complex, ambiguous, or multi-phase work clearer and more collaborative for the user. A good plan should break the task into meaningful, logically ordered steps that are easy to verify as you go. - -Note that plans are not for padding out simple work with filler steps or stating the obvious. The content of your plan should not involve doing anything that you aren't capable of doing (i.e. don't try to test things that you can't test). Do not use plans for simple or single-step queries that you can just do or answer immediately. - -Do not repeat the full contents of the plan after an `update_plan` call — the harness already displays it. Instead, summarize the change made and highlight any important context or next step. - -Before running a command, consider whether or not you have completed the previous step, and make sure to mark it as completed before moving on to the next step. It may be the case that you complete all steps in your plan after a single pass of implementation. If this is the case, you can simply mark all the planned steps as completed. Sometimes, you may need to change plans in the middle of a task: call `update_plan` with the updated plan and make sure to provide an `explanation` of the rationale when doing so. - -Use a plan when: - -- The task is non-trivial and will require multiple actions over a long time horizon. -- There are logical phases or dependencies where sequencing matters. -- The work has ambiguity that benefits from outlining high-level goals. -- You want intermediate checkpoints for feedback and validation. -- When the user asked you to do more than one thing in a single prompt -- The user has asked you to use the plan tool (aka "TODOs") -- You generate additional steps while working, and plan to do them before yielding to the user - -### Examples - -**High-quality plans** - -Example 1: - -1. Add CLI entry with file args -2. Parse Markdown via CommonMark library -3. Apply semantic HTML template -4. 
Handle code blocks, images, links -5. Add error handling for invalid files - -Example 2: - -1. Define CSS variables for colors -2. Add toggle with localStorage state -3. Refactor components to use variables -4. Verify all views for readability -5. Add smooth theme-change transition - -Example 3: - -1. Set up Node.js + WebSocket server -2. Add join/leave broadcast events -3. Implement messaging with timestamps -4. Add usernames + mention highlighting -5. Persist messages in lightweight DB -6. Add typing indicators + unread count - -**Low-quality plans** - -Example 1: - -1. Create CLI tool -2. Add Markdown parser -3. Convert to HTML - -Example 2: - -1. Add dark mode toggle -2. Save preference -3. Make styles look good - -Example 3: - -1. Create single-file HTML game -2. Run quick sanity check -3. Summarize usage instructions - -If you need to write a plan, only write high quality plans, not low quality ones. - -## Task execution - -You are a coding agent. Please keep going until the query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. Autonomously resolve the query to the best of your ability, using the tools available to you, before coming back to the user. Do NOT guess or make up an answer. - -You MUST adhere to the following criteria when solving queries: - -- Working on the repo(s) in the current environment is allowed, even if they are proprietary. -- Analyzing code for vulnerabilities is allowed. -- Showing user code and tool call details is allowed. -- Use the `apply_patch` tool to edit files (NEVER try `applypatch` or `apply-patch`, only `apply_patch`): {"command":["apply_patch","*** Begin Patch\\n*** Update File: path/to/file.py\\n@@ def example():\\n- pass\\n+ return 123\\n*** End Patch"]} - -If completing the user's task requires writing or modifying files, your code and final answer should follow these coding guidelines, though user instructions (i.e. 
AGENTS.md) may override these guidelines: - -- Fix the problem at the root cause rather than applying surface-level patches, when possible. -- Avoid unneeded complexity in your solution. -- Do not attempt to fix unrelated bugs or broken tests. It is not your responsibility to fix them. (You may mention them to the user in your final message though.) -- Update documentation as necessary. -- Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task. -- Use `git log` and `git blame` to search the history of the codebase if additional context is required. -- NEVER add copyright or license headers unless specifically requested. -- Do not waste tokens by re-reading files after calling `apply_patch` on them. The tool call will fail if it didn't work. The same goes for making folders, deleting folders, etc. -- Do not `git commit` your changes or create new git branches unless explicitly requested. -- Do not add inline comments within code unless explicitly requested. -- Do not use one-letter variable names unless explicitly requested. -- NEVER output inline citations like "【F:README.md†L5-L14】" in your outputs. The CLI is not able to render these so they will just be broken in the UI. Instead, if you output valid filepaths, users will be able to click on them to open the files in their editor. - -## Validating your work - -If the codebase has tests or the ability to build or run, consider using them to verify that your work is complete. - -When testing, your philosophy should be to start as specific as possible to the code you changed so that you can catch issues efficiently, then make your way to broader tests as you build confidence. If there's no test for the code you changed, and if the adjacent patterns in the codebases show that there's a logical place for you to add a test, you may do so. However, do not add tests to codebases with no tests. 
- -Similarly, once you're confident in correctness, you can suggest or use formatting commands to ensure that your code is well formatted. If there are issues you can iterate up to 3 times to get formatting right, but if you still can't manage it's better to save the user time and present them a correct solution where you call out the formatting in your final message. If the codebase does not have a formatter configured, do not add one. - -For all of testing, running, building, and formatting, do not attempt to fix unrelated bugs. It is not your responsibility to fix them. (You may mention them to the user in your final message though.) - -Be mindful of whether to run validation commands proactively. In the absence of behavioral guidance: - -- When running in non-interactive approval modes like **never** or **on-failure**, proactively run tests, lint and do whatever you need to ensure you've completed the task. -- When working in interactive approval modes like **untrusted**, or **on-request**, hold off on running tests or lint commands until the user is ready for you to finalize your output, because these commands take time to run and slow down iteration. Instead suggest what you want to do next, and let the user confirm first. -- When working on test-related tasks, such as adding tests, fixing tests, or reproducing a bug to verify behavior, you may proactively run tests regardless of approval mode. Use your judgement to decide whether this is a test-related task. - -## Ambition vs. precision - -For tasks that have no prior context (i.e. the user is starting something brand new), you should feel free to be ambitious and demonstrate creativity with your implementation. - -If you're operating in an existing codebase, you should make sure you do exactly what the user asks with surgical precision. Treat the surrounding codebase with respect, and don't overstep (i.e. changing filenames or variables unnecessarily). 
You should balance being sufficiently ambitious and proactive when completing tasks of this nature. - -You should use judicious initiative to decide on the right level of detail and complexity to deliver based on the user's needs. This means showing good judgment that you're capable of doing the right extras without gold-plating. This might be demonstrated by high-value, creative touches when scope of the task is vague; while being surgical and targeted when scope is tightly specified. - -## Sharing progress updates - -For especially longer tasks that you work on (i.e. requiring many tool calls, or a plan with multiple steps), you should provide progress updates back to the user at reasonable intervals. These updates should be structured as a concise sentence or two (no more than 8-10 words long) recapping progress so far in plain language: this update demonstrates your understanding of what needs to be done, progress so far (i.e. files explores, subtasks complete), and where you're going next. - -Before doing large chunks of work that may incur latency as experienced by the user (i.e. writing a new file), you should send a concise message to the user with an update indicating what you're about to do to ensure they know what you're spending time on. Don't start editing or writing large files before informing the user what you are doing and why. - -The messages you send before tool calls should describe what is immediately about to be done next in very concise language. If there was previous work done, this preamble message should also include a note about the work done so far to bring the user along. - -## Presenting your work and final message - -Your final message should read naturally, like an update from a concise teammate. For casual conversation, brainstorming tasks, or quick questions from the user, respond in a friendly, conversational tone. You should ask questions, suggest ideas, and adapt to the user’s style. 
If you've finished a large amount of work, when describing what you've done to the user, you should follow the final answer formatting guidelines to communicate substantive changes. You don't need to add structured formatting for one-word answers, greetings, or purely conversational exchanges. - -You can skip heavy formatting for single, simple actions or confirmations. In these cases, respond in plain sentences with any relevant next step or quick option. Reserve multi-section structured responses for results that need grouping or explanation. - -The user is working on the same computer as you, and has access to your work. As such there's no need to show the full contents of large files you have already written unless the user explicitly asks for them. Similarly, if you've created or modified files using `apply_patch`, there's no need to tell users to "save the file" or "copy the code into a file"—just reference the file path. - -If there's something that you think you could help with as a logical next step, concisely ask the user if they want you to do so. Good examples of this are running tests, committing changes, or building out the next logical component. If there’s something that you couldn't do (even with approval) but that the user might want to do (such as verifying changes by running the app), include those instructions succinctly. - -Brevity is very important as a default. You should be very concise (i.e. no more than 10 lines), but can relax this requirement for tasks where additional detail and comprehensiveness is important for the user's understanding. - -### Final answer structure and style guidelines - -You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value. - -**Section Headers** - -- Use only when they improve clarity — they are not mandatory for every answer. 
-- Choose descriptive names that fit the content -- Keep headers short (1–3 words) and in `**Title Case**`. Always start headers with `**` and end with `**` -- Leave no blank line before the first bullet under a header. -- Section headers should only be used where they genuinely improve scanability; avoid fragmenting the answer. - -**Bullets** - -- Use `-` followed by a space for every bullet. -- Merge related points when possible; avoid a bullet for every trivial detail. -- Keep bullets to one line unless breaking for clarity is unavoidable. -- Group into short lists (4–6 bullets) ordered by importance. -- Use consistent keyword phrasing and formatting across sections. - -**Monospace** - -- Wrap all commands, file paths, env vars, and code identifiers in backticks (`` `...` ``). -- Apply to inline examples and to bullet keywords if the keyword itself is a literal file/command. -- Never mix monospace and bold markers; choose one based on whether it’s a keyword (`**`) or inline code/path (`` ` ``). - -**File References** -When referencing files in your response, make sure to include the relevant start line and always follow the below rules: - * Use inline code to make file paths clickable. - * Each reference should have a stand alone path. Even if it's the same file. - * Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix. - * Line/column (1‑based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1). - * Do not use URIs like file://, vscode://, or https://. - * Do not provide range of lines - * Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5 - -**Structure** - -- Place related bullets together; don’t mix unrelated concepts in the same section. -- Order sections from general → specific → supporting info. -- For subsections (e.g., “Binaries” under “Rust Workspace”), introduce with a bolded keyword bullet, then list items under it. 
-- Match structure to complexity: - - Multi-part or detailed results → use clear headers and grouped bullets. - - Simple results → minimal headers, possibly just a short list or paragraph. - -**Tone** - -- Keep the voice collaborative and natural, like a coding partner handing off work. -- Be concise and factual — no filler or conversational commentary and avoid unnecessary repetition -- Use present tense and active voice (e.g., “Runs tests” not “This will run tests”). -- Keep descriptions self-contained; don’t refer to “above” or “below”. -- Use parallel structure in lists for consistency. - -**Don’t** - -- Don’t use literal words “bold” or “monospace” in the content. -- Don’t nest bullets or create deep hierarchies. -- Don’t output ANSI escape codes directly — the CLI renderer applies them. -- Don’t cram unrelated keywords into a single bullet; split for clarity. -- Don’t let keyword lists run long — wrap or reformat for scanability. - -Generally, ensure your final answers adapt their shape and depth to the request. For example, answers to code explanations should have a precise, structured explanation with code references that answer the question directly. For tasks with a simple implementation, lead with the outcome and supplement only with what’s needed for clarity. Larger changes can be presented as a logical walkthrough of your approach, grouping related steps, explaining rationale where it adds value, and highlighting next actions to accelerate the user. Your answers should provide the right level of detail while being easily scannable. - -For casual greetings, acknowledgements, or other one-off conversational messages that are not delivering substantive information or structured results, respond naturally without section headers or bullet formatting. 
- -# Tool Guidelines - -## Shell commands - -When using the shell, you must adhere to the following guidelines: - -- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.) -- Do not use python scripts to attempt to output larger chunks of a file. - -## `update_plan` - -A tool named `update_plan` is available to you. You can use it to keep an up‑to‑date, step‑by‑step plan for the task. - -To create a new plan, call `update_plan` with a short list of 1‑sentence steps (no more than 5-7 words each) with a `status` for each step (`pending`, `in_progress`, or `completed`). - -When steps have been completed, use `update_plan` to mark each finished step as `completed` and the next step you are working on as `in_progress`. There should always be exactly one `in_progress` step until everything is done. You can mark multiple items as complete in a single `update_plan` call. - -If all steps are complete, ensure you call `update_plan` to mark all steps as `completed`. 
diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index d4d70536..5861a811 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -28,6 +28,17 @@ type RateLimitService struct { usageCache map[int64]*geminiUsageCacheEntry } +// SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。 +type SuccessfulTestRecoveryResult struct { + ClearedError bool + ClearedRateLimit bool +} + +// AccountRecoveryOptions 控制账号恢复时的附加行为。 +type AccountRecoveryOptions struct { + InvalidateToken bool +} + type geminiUsageCacheEntry struct { windowStart time.Time cachedAt time.Time @@ -87,6 +98,9 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode) return ErrorPolicySkipped } + if account.IsPoolMode() { + return ErrorPolicySkipped + } if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) { return ErrorPolicyTempUnscheduled } @@ -96,9 +110,16 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun // HandleUpstreamError 处理上游错误响应,标记账号状态 // 返回是否应该停止该账号的调度 func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { + customErrorCodesEnabled := account.IsCustomErrorCodesEnabled() + + // 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。 + if account.IsPoolMode() && !customErrorCodesEnabled { + slog.Info("pool_mode_error_skipped", "account_id", account.ID, "status_code", statusCode) + return false + } + // apikey 类型账号:检查自定义错误码配置 // 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载) - customErrorCodesEnabled := account.IsCustomErrorCodesEnabled() if !account.ShouldHandleErrorCode(statusCode) { slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode) return false @@ -128,8 +149,9 @@ func (s *RateLimitService) 
HandleUpstreamError(ctx context.Context, account *Acc } // 其他 400 错误(如参数问题)不处理,不禁用账号 case 401: - // 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新 - if account.Type == AccountTypeOAuth { + // OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。 + // Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。 + if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity { // 1. 失效缓存 if s.tokenCacheInvalidator != nil { if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { @@ -146,13 +168,29 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } else { slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) } + // 3. 临时不可调度,替代 SetError(保持 status=active 让刷新服务能拾取) + msg := "Authentication failed (401): invalid or expired credentials" + if upstreamMsg != "" { + msg = "OAuth 401: " + upstreamMsg + } + cooldownMinutes := s.cfg.RateLimit.OAuth401CooldownMinutes + if cooldownMinutes <= 0 { + cooldownMinutes = 10 + } + until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil { + slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err) + } + shouldDisable = true + } else { + // 非 OAuth / Antigravity OAuth:保持 SetError 行为 + msg := "Authentication failed (401): invalid or expired credentials" + if upstreamMsg != "" { + msg = "Authentication failed (401): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + shouldDisable = true } - msg := "Authentication failed (401): invalid or expired credentials" - if upstreamMsg != "" { - msg = "Authentication failed (401): " + upstreamMsg - } - s.handleAuthError(ctx, account, msg) - shouldDisable = true case 402: // 支付要求:余额不足或计费问题,停止调度 msg := "Payment required (402): insufficient balance or billing issue" @@ -162,11 +200,6 @@ func (s *RateLimitService) 
HandleUpstreamError(ctx context.Context, account *Acc s.handleAuthError(ctx, account, msg) shouldDisable = true case 403: - // 禁止访问:停止调度,记录错误 - msg := "Access forbidden (403): account may be suspended or lack permissions" - if upstreamMsg != "" { - msg = "Access forbidden (403): " + upstreamMsg - } logger.LegacyPrintf( "service.ratelimit", "[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s", @@ -178,8 +211,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc upstreamMsg, truncateForLog(responseBody, 1024), ) - s.handleAuthError(ctx, account, msg) - shouldDisable = true + shouldDisable = s.handle403(ctx, account, upstreamMsg, responseBody) case 429: s.handle429(ctx, account, headers, responseBody) shouldDisable = false @@ -584,6 +616,62 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg) } +// handle403 处理 403 Forbidden 错误 +// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用; +// 其他平台保持原有 SetError 行为。 +func (s *RateLimitService) handle403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) { + if account.Platform == PlatformAntigravity { + return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody) + } + // 非 Antigravity 平台:保持原有行为 + msg := "Access forbidden (403): account may be suspended or lack permissions" + if upstreamMsg != "" { + msg = "Access forbidden (403): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + return true +} + +// handleAntigravity403 处理 Antigravity 平台的 403 错误 +// validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复) +// violation(违规封号)→ 永久 SetError(需人工处理) +// generic(通用禁止)→ 永久 SetError +func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) { + 
fbType := classifyForbiddenType(string(responseBody)) + + switch fbType { + case forbiddenTypeValidation: + // VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复 + msg := "Validation required (403): account needs Google verification" + if upstreamMsg != "" { + msg = "Validation required (403): " + upstreamMsg + } + if validationURL := extractValidationURL(string(responseBody)); validationURL != "" { + msg += " | validation_url: " + validationURL + } + s.handleAuthError(ctx, account, msg) + return true + + case forbiddenTypeViolation: + // 违规封号: 永久禁用,需人工处理 + msg := "Account violation (403): terms of service violation" + if upstreamMsg != "" { + msg = "Account violation (403): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + return true + + default: + // 通用 403: 保持原有行为 + msg := "Access forbidden (403): account may be suspended or lack permissions" + if upstreamMsg != "" { + msg = "Access forbidden (403): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + return true + } +} + // handleCustomErrorCode 处理自定义错误码,停止账号调度 func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) { msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg @@ -599,6 +687,7 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) { // 1. 
OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded) if account.Platform == PlatformOpenAI { + s.persistOpenAICodexSnapshot(ctx, account, headers) if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) @@ -660,7 +749,17 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head } } - // 没有重置时间,使用默认5分钟 + // Anthropic 平台:没有限流重置时间的 429 可能是非真实限流(如 Extra usage required), + // 不标记账号限流状态,直接透传错误给客户端 + if account.Platform == PlatformAnthropic { + slog.Warn("rate_limit_429_no_reset_time_skipped", + "account_id", account.ID, + "platform", account.Platform, + "reason", "no rate limit reset time in headers, likely not a real rate limit") + return + } + + // 其他平台:没有重置时间,使用默认5分钟 resetAt := time.Now().Add(5 * time.Minute) slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { @@ -852,6 +951,23 @@ func pickSooner(a, b *time.Time) *time.Time { } } +func (s *RateLimitService) persistOpenAICodexSnapshot(ctx context.Context, account *Account, headers http.Header) { + if s == nil || s.accountRepo == nil || account == nil || headers == nil { + return + } + snapshot := ParseCodexRateLimitHeaders(headers) + if snapshot == nil { + return + } + updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) + if len(updates) == 0 { + return + } + if err := s.accountRepo.UpdateExtra(ctx, account.ID, updates); err != nil { + slog.Warn("openai_codex_snapshot_persist_failed", "account_id", account.ID, "error", err) + } +} + // parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳 // OpenAI 的 usage_limit_reached 错误格式: // @@ -944,12 +1060,27 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc windowStart = &start windowEnd 
= &end slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status) + // 窗口重置时清除旧的 utilization,避免残留上个窗口的数据 + _ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{ + "session_window_utilization": nil, + }) } if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil { slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err) } + // 存储真实的 utilization 值(0-1 小数),供 estimateSetupTokenUsage 使用 + if utilStr := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilStr != "" { + if util, err := strconv.ParseFloat(utilStr, 64); err == nil { + if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{ + "session_window_utilization": util, + }); err != nil { + slog.Warn("session_window_utilization_update_failed", "account_id", account.ID, "error", err) + } + } + } + // 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态 if status == "allowed" && account.IsRateLimited() { if err := s.ClearRateLimit(ctx, account.ID); err != nil { @@ -981,6 +1112,42 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) return nil } +// RecoverAccountState 按需恢复账号的可恢复运行时状态。 +func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err != nil { + return nil, err + } + + result := &SuccessfulTestRecoveryResult{} + if account.Status == StatusError { + if err := s.accountRepo.ClearError(ctx, accountID); err != nil { + return nil, err + } + result.ClearedError = true + if options.InvalidateToken && s.tokenCacheInvalidator != nil && account.IsOAuth() { + if invalidateErr := s.tokenCacheInvalidator.InvalidateToken(ctx, account); invalidateErr != nil { + slog.Warn("recover_account_state_invalidate_token_failed", "account_id", accountID, "error", invalidateErr) + } + } + 
} + + if hasRecoverableRuntimeState(account) { + if err := s.ClearRateLimit(ctx, accountID); err != nil { + return nil, err + } + result.ClearedRateLimit = true + } + + return result, nil +} + +// RecoverAccountAfterSuccessfulTest 将一次成功测试视为正常请求, +// 按需恢复 error / rate-limit / overload / temp-unsched / model-rate-limit 等运行时状态。 +func (s *RateLimitService) RecoverAccountAfterSuccessfulTest(ctx context.Context, accountID int64) (*SuccessfulTestRecoveryResult, error) { + return s.RecoverAccountState(ctx, accountID, AccountRecoveryOptions{}) +} + func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error { if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil { return err @@ -997,6 +1164,37 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID return nil } +func hasRecoverableRuntimeState(account *Account) bool { + if account == nil { + return false + } + if account.RateLimitedAt != nil || account.RateLimitResetAt != nil || account.OverloadUntil != nil || account.TempUnschedulableUntil != nil { + return true + } + if len(account.Extra) == 0 { + return false + } + return hasNonEmptyMapValue(account.Extra, "model_rate_limits") || + hasNonEmptyMapValue(account.Extra, "antigravity_quota_scopes") +} + +func hasNonEmptyMapValue(extra map[string]any, key string) bool { + raw, ok := extra[key] + if !ok || raw == nil { + return false + } + switch typed := raw.(type) { + case map[string]any: + return len(typed) > 0 + case map[string]string: + return len(typed) > 0 + case []any: + return len(typed) > 0 + default: + return true + } +} + func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID int64) (*TempUnschedState, error) { now := time.Now().Unix() if s.tempUnschedCache != nil { @@ -1065,6 +1263,23 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac if !account.IsTempUnschedulableEnabled() { return false } + // 401 首次命中可临时不可调度(给 token 刷新窗口); 
+ // 若历史上已因 401 进入过临时不可调度,则本次应升级为 error(返回 false 交由默认错误逻辑处理)。 + // Antigravity 跳过:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制,无需升级逻辑。 + if statusCode == http.StatusUnauthorized && account.Platform != PlatformAntigravity { + reason := account.TempUnschedulableReason + // 缓存可能没有 reason,从 DB 回退读取 + if reason == "" { + if dbAcc, err := s.accountRepo.GetByID(ctx, account.ID); err == nil && dbAcc != nil { + reason = dbAcc.TempUnschedulableReason + } + } + if wasTempUnschedByStatusCode(reason, statusCode) { + slog.Info("401_escalated_to_error", "account_id", account.ID, + "reason", "previous temp-unschedulable was also 401") + return false + } + } rules := account.GetTempUnschedulableRules() if len(rules) == 0 { return false @@ -1096,6 +1311,22 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac return false } +func wasTempUnschedByStatusCode(reason string, statusCode int) bool { + if statusCode <= 0 { + return false + } + reason = strings.TrimSpace(reason) + if reason == "" { + return false + } + + var state TempUnschedState + if err := json.Unmarshal([]byte(reason), &state); err != nil { + return false + } + return state.StatusCode == statusCode +} + func matchTempUnschedKeyword(bodyLower string, keywords []string) string { if bodyLower == "" { return "" diff --git a/backend/internal/service/ratelimit_service_401_db_fallback_test.go b/backend/internal/service/ratelimit_service_401_db_fallback_test.go new file mode 100644 index 00000000..d245b5d5 --- /dev/null +++ b/backend/internal/service/ratelimit_service_401_db_fallback_test.go @@ -0,0 +1,153 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// dbFallbackRepoStub extends errorPolicyRepoStub with a configurable DB account +// returned by GetByID, simulating cache miss + DB fallback. 
+type dbFallbackRepoStub struct { + errorPolicyRepoStub + dbAccount *Account // returned by GetByID when non-nil +} + +func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) { + if r.dbAccount != nil && r.dbAccount.ID == id { + return r.dbAccount, nil + } + return nil, nil // not found, no error +} + +func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) { + // Scenario: cache account has empty TempUnschedulableReason (cache miss), + // but DB account has a previous 401 record. + // Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error). + // Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules). + t.Run("gemini_escalates", func(t *testing.T) { + repo := &dbFallbackRepoStub{ + dbAccount: &Account{ + ID: 20, + TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, + }, + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 20, + Type: AccountTypeOAuth, + Platform: PlatformGemini, + TempUnschedulableReason: "", + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + } + + result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) + require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate") + }) + + t.Run("antigravity_stays_temp", func(t *testing.T) { + repo := &dbFallbackRepoStub{ + dbAccount: &Account{ + ID: 20, + TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, + }, + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 20, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + TempUnschedulableReason: "", + Credentials: map[string]any{ + 
"temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + } + + result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) + require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled") + }) +} + +func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) { + // Scenario: cache account has empty TempUnschedulableReason, + // DB also has no previous 401 record → should NOT escalate (first hit → temp unscheduled). + repo := &dbFallbackRepoStub{ + dbAccount: &Account{ + ID: 21, + TempUnschedulableReason: "", // DB also empty + }, + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 21, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + TempUnschedulableReason: "", + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + } + + result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) + require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with no DB record should temp-unschedule") +} + +func TestCheckErrorPolicy_401_DBFallback_DBError_FirstHit(t *testing.T) { + // Scenario: cache account has empty TempUnschedulableReason, + // DB lookup returns nil (not found) → should treat as first hit → temp unscheduled. 
+ repo := &dbFallbackRepoStub{ + dbAccount: nil, // GetByID returns nil, nil + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 22, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + TempUnschedulableReason: "", + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + } + + result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) + require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with DB not found should temp-unschedule") +} diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 36357a4b..4a6e5d6c 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -41,47 +41,57 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc return r.err } -func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) { - tests := []struct { - name string - platform string - }{ - {name: "gemini", platform: PlatformGemini}, - {name: "antigravity", platform: PlatformAntigravity}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - repo := &rateLimitAccountRepoStub{} - invalidator := &tokenCacheInvalidatorRecorder{} - service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) - service.SetTokenCacheInvalidator(invalidator) - account := &Account{ - ID: 100, - Platform: tt.platform, - Type: AccountTypeOAuth, - Credentials: map[string]any{ - "temp_unschedulable_enabled": true, - "temp_unschedulable_rules": []any{ - map[string]any{ - "error_code": 401, - "keywords": []any{"unauthorized"}, - "duration_minutes": 30, - "description": "custom rule", - }, 
+func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) { + t.Run("gemini", func(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + invalidator := &tokenCacheInvalidatorRecorder{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetTokenCacheInvalidator(invalidator) + account := &Account{ + ID: 100, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": 401, + "keywords": []any{"unauthorized"}, + "duration_minutes": 30, + "description": "custom rule", }, }, - } + }, + } - shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) - require.True(t, shouldDisable) - require.Equal(t, 1, repo.setErrorCalls) - require.Equal(t, 0, repo.tempCalls) - require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)") - require.Len(t, invalidator.accounts, 1) - }) - } + require.True(t, shouldDisable) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.tempCalls) + require.Len(t, invalidator.accounts, 1) + }) + + t.Run("antigravity_401_uses_SetError", func(t *testing.T) { + // Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制, + // HandleUpstreamError 中走 SetError 路径。 + repo := &rateLimitAccountRepoStub{} + invalidator := &tokenCacheInvalidatorRecorder{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetTokenCacheInvalidator(invalidator) + account := &Account{ + ID: 100, + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, 
repo.setErrorCalls) + require.Equal(t, 0, repo.tempCalls) + require.Empty(t, invalidator.accounts) + }) } func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { @@ -98,7 +108,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) - require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.tempCalls) require.Len(t, invalidator.accounts, 1) } diff --git a/backend/internal/service/ratelimit_service_clear_test.go b/backend/internal/service/ratelimit_service_clear_test.go index f48151ed..1d7a02fc 100644 --- a/backend/internal/service/ratelimit_service_clear_test.go +++ b/backend/internal/service/ratelimit_service_clear_test.go @@ -6,6 +6,7 @@ import ( "context" "errors" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" @@ -13,16 +14,34 @@ import ( type rateLimitClearRepoStub struct { mockAccountRepoForGemini + getByIDAccount *Account + getByIDErr error + getByIDCalls int + clearErrorCalls int clearRateLimitCalls int clearAntigravityCalls int clearModelRateLimitCalls int clearTempUnschedCalls int + clearErrorErr error clearRateLimitErr error clearAntigravityErr error clearModelRateLimitErr error clearTempUnschedulableErr error } +func (r *rateLimitClearRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) { + r.getByIDCalls++ + if r.getByIDErr != nil { + return nil, r.getByIDErr + } + return r.getByIDAccount, nil +} + +func (r *rateLimitClearRepoStub) ClearError(ctx context.Context, id int64) error { + r.clearErrorCalls++ + return r.clearErrorErr +} + func (r *rateLimitClearRepoStub) ClearRateLimit(ctx context.Context, id int64) error { r.clearRateLimitCalls++ return r.clearRateLimitErr @@ -48,6 +67,11 @@ type tempUnschedCacheRecorder struct { 
deleteErr error } +type recoverTokenInvalidatorStub struct { + accounts []*Account + err error +} + func (c *tempUnschedCacheRecorder) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error { return nil } @@ -61,6 +85,11 @@ func (c *tempUnschedCacheRecorder) DeleteTempUnsched(ctx context.Context, accoun return c.deleteErr } +func (s *recoverTokenInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error { + s.accounts = append(s.accounts, account) + return s.err +} + func TestRateLimitService_ClearRateLimit_AlsoClearsTempUnschedulable(t *testing.T) { repo := &rateLimitClearRepoStub{} cache := &tempUnschedCacheRecorder{} @@ -170,3 +199,108 @@ func TestRateLimitService_ClearRateLimit_WithoutTempUnschedCache(t *testing.T) { require.Equal(t, 1, repo.clearModelRateLimitCalls) require.Equal(t, 1, repo.clearTempUnschedCalls) } + +func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLimitRelatedState(t *testing.T) { + now := time.Now() + repo := &rateLimitClearRepoStub{ + getByIDAccount: &Account{ + ID: 42, + Status: StatusError, + RateLimitedAt: &now, + TempUnschedulableUntil: &now, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": now.Format(time.RFC3339), + }, + }, + "antigravity_quota_scopes": map[string]any{"gemini": true}, + }, + }, + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.ClearedError) + require.True(t, result.ClearedRateLimit) + + require.Equal(t, 1, repo.getByIDCalls) + require.Equal(t, 1, repo.clearErrorCalls) + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, 
repo.clearTempUnschedCalls) + require.Equal(t, []int64{42}, cache.deletedIDs) +} + +func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) { + repo := &rateLimitClearRepoStub{ + getByIDAccount: &Account{ + ID: 7, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{}, + }, + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 7) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.ClearedError) + require.False(t, result.ClearedRateLimit) + + require.Equal(t, 1, repo.getByIDCalls) + require.Equal(t, 0, repo.clearErrorCalls) + require.Equal(t, 0, repo.clearRateLimitCalls) + require.Equal(t, 0, repo.clearAntigravityCalls) + require.Equal(t, 0, repo.clearModelRateLimitCalls) + require.Equal(t, 0, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearErrorFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + getByIDAccount: &Account{ + ID: 9, + Status: StatusError, + }, + clearErrorErr: errors.New("clear error failed"), + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 9) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, 1, repo.getByIDCalls) + require.Equal(t, 1, repo.clearErrorCalls) + require.Equal(t, 0, repo.clearRateLimitCalls) +} + +func TestRateLimitService_RecoverAccountState_InvalidatesOAuthTokenOnErrorRecovery(t *testing.T) { + repo := &rateLimitClearRepoStub{ + getByIDAccount: &Account{ + ID: 21, + Type: AccountTypeOAuth, + Status: StatusError, + }, + } + invalidator := &recoverTokenInvalidatorStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc.SetTokenCacheInvalidator(invalidator) + + result, err := 
svc.RecoverAccountState(context.Background(), 21, AccountRecoveryOptions{ + InvalidateToken: true, + }) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.ClearedError) + require.False(t, result.ClearedRateLimit) + require.Equal(t, 1, repo.clearErrorCalls) + require.Len(t, invalidator.accounts, 1) + require.Equal(t, int64(21), invalidator.accounts[0].ID) +} diff --git a/backend/internal/service/ratelimit_service_openai_test.go b/backend/internal/service/ratelimit_service_openai_test.go index 00902068..89c754c8 100644 --- a/backend/internal/service/ratelimit_service_openai_test.go +++ b/backend/internal/service/ratelimit_service_openai_test.go @@ -1,6 +1,9 @@ +//go:build unit + package service import ( + "context" "net/http" "testing" "time" @@ -141,6 +144,51 @@ func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) { } } +type openAI429SnapshotRepo struct { + mockAccountRepoForGemini + rateLimitedID int64 + updatedExtra map[string]any +} + +func (r *openAI429SnapshotRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error { + r.rateLimitedID = id + return nil +} + +func (r *openAI429SnapshotRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + r.updatedExtra = updates + return nil +} + +func TestHandle429_OpenAIPersistsCodexSnapshotImmediately(t *testing.T) { + repo := &openAI429SnapshotRepo{} + svc := NewRateLimitService(repo, nil, nil, nil, nil) + account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + svc.handle429(context.Background(), account, headers, nil) + + if repo.rateLimitedID 
!= account.ID { + t.Fatalf("rateLimitedID = %d, want %d", repo.rateLimitedID, account.ID) + } + if len(repo.updatedExtra) == 0 { + t.Fatal("expected codex snapshot to be persisted on 429") + } + if got := repo.updatedExtra["codex_5h_used_percent"]; got != 100.0 { + t.Fatalf("codex_5h_used_percent = %v, want 100", got) + } + if got := repo.updatedExtra["codex_7d_used_percent"]; got != 100.0 { + t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } +} + func TestNormalizedCodexLimits(t *testing.T) { // Test the Normalize() method directly pUsed := 100.0 diff --git a/backend/internal/service/refresh_policy.go b/backend/internal/service/refresh_policy.go new file mode 100644 index 00000000..7f299be0 --- /dev/null +++ b/backend/internal/service/refresh_policy.go @@ -0,0 +1,99 @@ +package service + +import "time" + +// ProviderRefreshErrorAction 定义 provider 在刷新失败时的处理动作。 +type ProviderRefreshErrorAction int + +const ( + // ProviderRefreshErrorReturn 失败即返回错误(不降级旧 token)。 + ProviderRefreshErrorReturn ProviderRefreshErrorAction = iota + // ProviderRefreshErrorUseExistingToken 失败后继续使用现有 token。 + ProviderRefreshErrorUseExistingToken +) + +// ProviderLockHeldAction 定义 provider 在刷新锁被占用时的处理动作。 +type ProviderLockHeldAction int + +const ( + // ProviderLockHeldUseExistingToken 直接使用现有 token。 + ProviderLockHeldUseExistingToken ProviderLockHeldAction = iota + // ProviderLockHeldWaitForCache 等待后重试缓存读取。 + ProviderLockHeldWaitForCache +) + +// ProviderRefreshPolicy 描述 provider 的平台差异策略。 +type ProviderRefreshPolicy struct { + OnRefreshError ProviderRefreshErrorAction + OnLockHeld ProviderLockHeldAction + FailureTTL time.Duration +} + +func ClaudeProviderRefreshPolicy() ProviderRefreshPolicy { + return ProviderRefreshPolicy{ + OnRefreshError: ProviderRefreshErrorUseExistingToken, + OnLockHeld: ProviderLockHeldWaitForCache, + FailureTTL: time.Minute, + } +} + +func OpenAIProviderRefreshPolicy() ProviderRefreshPolicy { + return ProviderRefreshPolicy{ + OnRefreshError: 
ProviderRefreshErrorUseExistingToken, + OnLockHeld: ProviderLockHeldWaitForCache, + FailureTTL: time.Minute, + } +} + +func GeminiProviderRefreshPolicy() ProviderRefreshPolicy { + return ProviderRefreshPolicy{ + OnRefreshError: ProviderRefreshErrorReturn, + OnLockHeld: ProviderLockHeldUseExistingToken, + FailureTTL: 0, + } +} + +func AntigravityProviderRefreshPolicy() ProviderRefreshPolicy { + return ProviderRefreshPolicy{ + OnRefreshError: ProviderRefreshErrorReturn, + OnLockHeld: ProviderLockHeldUseExistingToken, + FailureTTL: 0, + } +} + +// BackgroundSkipAction 定义后台刷新服务在“未实际刷新”场景的计数方式。 +type BackgroundSkipAction int + +const ( + // BackgroundSkipAsSkipped 计入 skipped(保持当前默认行为)。 + BackgroundSkipAsSkipped BackgroundSkipAction = iota + // BackgroundSkipAsSuccess 计入 success(仅用于兼容旧统计口径时可选)。 + BackgroundSkipAsSuccess +) + +// BackgroundRefreshPolicy 描述后台刷新服务的调用侧策略。 +type BackgroundRefreshPolicy struct { + OnLockHeld BackgroundSkipAction + OnAlreadyRefresh BackgroundSkipAction +} + +func DefaultBackgroundRefreshPolicy() BackgroundRefreshPolicy { + return BackgroundRefreshPolicy{ + OnLockHeld: BackgroundSkipAsSkipped, + OnAlreadyRefresh: BackgroundSkipAsSkipped, + } +} + +func (p BackgroundRefreshPolicy) handleLockHeld() error { + if p.OnLockHeld == BackgroundSkipAsSuccess { + return nil + } + return errRefreshSkipped +} + +func (p BackgroundRefreshPolicy) handleAlreadyRefreshed() error { + if p.OnAlreadyRefresh == BackgroundSkipAsSuccess { + return nil + } + return errRefreshSkipped +} diff --git a/backend/internal/service/registration_email_policy.go b/backend/internal/service/registration_email_policy.go new file mode 100644 index 00000000..875668c7 --- /dev/null +++ b/backend/internal/service/registration_email_policy.go @@ -0,0 +1,123 @@ +package service + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" +) + +var registrationEmailDomainPattern = regexp.MustCompile( + `^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)+$`, +) 
+ +// RegistrationEmailSuffix extracts normalized suffix in "@domain" form. +func RegistrationEmailSuffix(email string) string { + _, domain, ok := splitEmailForPolicy(email) + if !ok { + return "" + } + return "@" + domain +} + +// IsRegistrationEmailSuffixAllowed checks whether an email is allowed by suffix whitelist. +// Empty whitelist means allow all. +func IsRegistrationEmailSuffixAllowed(email string, whitelist []string) bool { + if len(whitelist) == 0 { + return true + } + suffix := RegistrationEmailSuffix(email) + if suffix == "" { + return false + } + for _, allowed := range whitelist { + if suffix == allowed { + return true + } + } + return false +} + +// NormalizeRegistrationEmailSuffixWhitelist normalizes and validates suffix whitelist items. +func NormalizeRegistrationEmailSuffixWhitelist(raw []string) ([]string, error) { + return normalizeRegistrationEmailSuffixWhitelist(raw, true) +} + +// ParseRegistrationEmailSuffixWhitelist parses persisted JSON into normalized suffixes. +// Invalid entries are ignored to keep old misconfigurations from breaking runtime reads. 
+func ParseRegistrationEmailSuffixWhitelist(raw string) []string { + raw = strings.TrimSpace(raw) + if raw == "" { + return []string{} + } + var items []string + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []string{} + } + normalized, _ := normalizeRegistrationEmailSuffixWhitelist(items, false) + if len(normalized) == 0 { + return []string{} + } + return normalized +} + +func normalizeRegistrationEmailSuffixWhitelist(raw []string, strict bool) ([]string, error) { + if len(raw) == 0 { + return nil, nil + } + + seen := make(map[string]struct{}, len(raw)) + out := make([]string, 0, len(raw)) + for _, item := range raw { + normalized, err := normalizeRegistrationEmailSuffix(item) + if err != nil { + if strict { + return nil, err + } + continue + } + if normalized == "" { + continue + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + out = append(out, normalized) + } + + if len(out) == 0 { + return nil, nil + } + return out, nil +} + +func normalizeRegistrationEmailSuffix(raw string) (string, error) { + value := strings.ToLower(strings.TrimSpace(raw)) + if value == "" { + return "", nil + } + + domain := value + if strings.Contains(value, "@") { + if !strings.HasPrefix(value, "@") || strings.Count(value, "@") != 1 { + return "", fmt.Errorf("invalid email suffix: %q", raw) + } + domain = strings.TrimPrefix(value, "@") + } + + if domain == "" || strings.Contains(domain, "@") || !registrationEmailDomainPattern.MatchString(domain) { + return "", fmt.Errorf("invalid email suffix: %q", raw) + } + + return "@" + domain, nil +} + +func splitEmailForPolicy(raw string) (local string, domain string, ok bool) { + email := strings.ToLower(strings.TrimSpace(raw)) + local, domain, found := strings.Cut(email, "@") + if !found || local == "" || domain == "" || strings.Contains(domain, "@") { + return "", "", false + } + return local, domain, true +} diff --git a/backend/internal/service/registration_email_policy_test.go 
b/backend/internal/service/registration_email_policy_test.go new file mode 100644 index 00000000..f0c46642 --- /dev/null +++ b/backend/internal/service/registration_email_policy_test.go @@ -0,0 +1,31 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNormalizeRegistrationEmailSuffixWhitelist(t *testing.T) { + got, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"example.com", "@EXAMPLE.COM", " @foo.bar "}) + require.NoError(t, err) + require.Equal(t, []string{"@example.com", "@foo.bar"}, got) +} + +func TestNormalizeRegistrationEmailSuffixWhitelist_Invalid(t *testing.T) { + _, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"@invalid_domain"}) + require.Error(t, err) +} + +func TestParseRegistrationEmailSuffixWhitelist(t *testing.T) { + got := ParseRegistrationEmailSuffixWhitelist(`["example.com","@foo.bar","@invalid_domain"]`) + require.Equal(t, []string{"@example.com", "@foo.bar"}, got) +} + +func TestIsRegistrationEmailSuffixAllowed(t *testing.T) { + require.True(t, IsRegistrationEmailSuffixAllowed("user@example.com", []string{"@example.com"})) + require.False(t, IsRegistrationEmailSuffixAllowed("user@sub.example.com", []string{"@example.com"})) + require.True(t, IsRegistrationEmailSuffixAllowed("user@any.com", []string{})) +} diff --git a/backend/internal/service/scheduled_test_port.go b/backend/internal/service/scheduled_test_port.go new file mode 100644 index 00000000..1c0fdf21 --- /dev/null +++ b/backend/internal/service/scheduled_test_port.go @@ -0,0 +1,52 @@ +package service + +import ( + "context" + "time" +) + +// ScheduledTestPlan represents a scheduled test plan domain model. 
+type ScheduledTestPlan struct { + ID int64 `json:"id"` + AccountID int64 `json:"account_id"` + ModelID string `json:"model_id"` + CronExpression string `json:"cron_expression"` + Enabled bool `json:"enabled"` + MaxResults int `json:"max_results"` + AutoRecover bool `json:"auto_recover"` + LastRunAt *time.Time `json:"last_run_at"` + NextRunAt *time.Time `json:"next_run_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// ScheduledTestResult represents a single test execution result. +type ScheduledTestResult struct { + ID int64 `json:"id"` + PlanID int64 `json:"plan_id"` + Status string `json:"status"` + ResponseText string `json:"response_text"` + ErrorMessage string `json:"error_message"` + LatencyMs int64 `json:"latency_ms"` + StartedAt time.Time `json:"started_at"` + FinishedAt time.Time `json:"finished_at"` + CreatedAt time.Time `json:"created_at"` +} + +// ScheduledTestPlanRepository defines the data access interface for test plans. +type ScheduledTestPlanRepository interface { + Create(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) + GetByID(ctx context.Context, id int64) (*ScheduledTestPlan, error) + ListByAccountID(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error) + ListDue(ctx context.Context, now time.Time) ([]*ScheduledTestPlan, error) + Update(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) + Delete(ctx context.Context, id int64) error + UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error +} + +// ScheduledTestResultRepository defines the data access interface for test results. 
+type ScheduledTestResultRepository interface { + Create(ctx context.Context, result *ScheduledTestResult) (*ScheduledTestResult, error) + ListByPlanID(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error) + PruneOldResults(ctx context.Context, planID int64, keepCount int) error +} diff --git a/backend/internal/service/scheduled_test_runner_service.go b/backend/internal/service/scheduled_test_runner_service.go new file mode 100644 index 00000000..f4d35f69 --- /dev/null +++ b/backend/internal/service/scheduled_test_runner_service.go @@ -0,0 +1,170 @@ +package service + +import ( + "context" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/robfig/cron/v3" +) + +const scheduledTestDefaultMaxWorkers = 10 + +// ScheduledTestRunnerService periodically scans due test plans and executes them. +type ScheduledTestRunnerService struct { + planRepo ScheduledTestPlanRepository + scheduledSvc *ScheduledTestService + accountTestSvc *AccountTestService + rateLimitSvc *RateLimitService + cfg *config.Config + + cron *cron.Cron + startOnce sync.Once + stopOnce sync.Once +} + +// NewScheduledTestRunnerService creates a new runner. +func NewScheduledTestRunnerService( + planRepo ScheduledTestPlanRepository, + scheduledSvc *ScheduledTestService, + accountTestSvc *AccountTestService, + rateLimitSvc *RateLimitService, + cfg *config.Config, +) *ScheduledTestRunnerService { + return &ScheduledTestRunnerService{ + planRepo: planRepo, + scheduledSvc: scheduledSvc, + accountTestSvc: accountTestSvc, + rateLimitSvc: rateLimitSvc, + cfg: cfg, + } +} + +// Start begins the cron ticker (every minute). 
+func (s *ScheduledTestRunnerService) Start() { + if s == nil { + return + } + s.startOnce.Do(func() { + loc := time.Local + if s.cfg != nil { + if parsed, err := time.LoadLocation(s.cfg.Timezone); err == nil && parsed != nil { + loc = parsed + } + } + + c := cron.New(cron.WithParser(scheduledTestCronParser), cron.WithLocation(loc)) + _, err := c.AddFunc("* * * * *", func() { s.runScheduled() }) + if err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] not started (invalid schedule): %v", err) + return + } + s.cron = c + s.cron.Start() + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] started (tick=every minute)") + }) +} + +// Stop gracefully shuts down the cron scheduler. +func (s *ScheduledTestRunnerService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + if s.cron != nil { + ctx := s.cron.Stop() + select { + case <-ctx.Done(): + case <-time.After(3 * time.Second): + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] cron stop timed out") + } + } + }) +} + +func (s *ScheduledTestRunnerService) runScheduled() { + // Delay 10s so execution lands at ~:10 of each minute instead of :00. 
+ time.Sleep(10 * time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + now := time.Now() + plans, err := s.planRepo.ListDue(ctx, now) + if err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] ListDue error: %v", err) + return + } + if len(plans) == 0 { + return + } + + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] found %d due plans", len(plans)) + + sem := make(chan struct{}, scheduledTestDefaultMaxWorkers) + var wg sync.WaitGroup + + for _, plan := range plans { + sem <- struct{}{} + wg.Add(1) + go func(p *ScheduledTestPlan) { + defer wg.Done() + defer func() { <-sem }() + s.runOnePlan(ctx, p) + }(plan) + } + + wg.Wait() +} + +func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *ScheduledTestPlan) { + result, err := s.accountTestSvc.RunTestBackground(ctx, plan.AccountID, plan.ModelID) + if err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d RunTestBackground error: %v", plan.ID, err) + return + } + + if err := s.scheduledSvc.SaveResult(ctx, plan.ID, plan.MaxResults, result); err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d SaveResult error: %v", plan.ID, err) + } + + // Auto-recover account if test succeeded and auto_recover is enabled. 
+ if result.Status == "success" && plan.AutoRecover { + s.tryRecoverAccount(ctx, plan.AccountID, plan.ID) + } + + nextRun, err := computeNextRun(plan.CronExpression, time.Now()) + if err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d computeNextRun error: %v", plan.ID, err) + return + } + + if err := s.planRepo.UpdateAfterRun(ctx, plan.ID, time.Now(), nextRun); err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d UpdateAfterRun error: %v", plan.ID, err) + } +} + +// tryRecoverAccount attempts to recover an account from recoverable runtime state. +func (s *ScheduledTestRunnerService) tryRecoverAccount(ctx context.Context, accountID int64, planID int64) { + if s.rateLimitSvc == nil { + return + } + + recovery, err := s.rateLimitSvc.RecoverAccountAfterSuccessfulTest(ctx, accountID) + if err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover failed: %v", planID, err) + return + } + if recovery == nil { + return + } + + if recovery.ClearedError { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover: account=%d recovered from error status", planID, accountID) + } + if recovery.ClearedRateLimit { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover: account=%d cleared rate-limit/runtime state", planID, accountID) + } +} diff --git a/backend/internal/service/scheduled_test_service.go b/backend/internal/service/scheduled_test_service.go new file mode 100644 index 00000000..c9bb3b6a --- /dev/null +++ b/backend/internal/service/scheduled_test_service.go @@ -0,0 +1,94 @@ +package service + +import ( + "context" + "fmt" + "time" + + "github.com/robfig/cron/v3" +) + +var scheduledTestCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) + +// ScheduledTestService provides CRUD operations for scheduled test 
plans and results. +type ScheduledTestService struct { + planRepo ScheduledTestPlanRepository + resultRepo ScheduledTestResultRepository +} + +// NewScheduledTestService creates a new ScheduledTestService. +func NewScheduledTestService( + planRepo ScheduledTestPlanRepository, + resultRepo ScheduledTestResultRepository, +) *ScheduledTestService { + return &ScheduledTestService{ + planRepo: planRepo, + resultRepo: resultRepo, + } +} + +// CreatePlan validates the cron expression, computes next_run_at, and persists the plan. +func (s *ScheduledTestService) CreatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) { + nextRun, err := computeNextRun(plan.CronExpression, time.Now()) + if err != nil { + return nil, fmt.Errorf("invalid cron expression: %w", err) + } + plan.NextRunAt = &nextRun + + if plan.MaxResults <= 0 { + plan.MaxResults = 50 + } + + return s.planRepo.Create(ctx, plan) +} + +// GetPlan retrieves a plan by ID. +func (s *ScheduledTestService) GetPlan(ctx context.Context, id int64) (*ScheduledTestPlan, error) { + return s.planRepo.GetByID(ctx, id) +} + +// ListPlansByAccount returns all plans for a given account. +func (s *ScheduledTestService) ListPlansByAccount(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error) { + return s.planRepo.ListByAccountID(ctx, accountID) +} + +// UpdatePlan validates cron and updates the plan. +func (s *ScheduledTestService) UpdatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) { + nextRun, err := computeNextRun(plan.CronExpression, time.Now()) + if err != nil { + return nil, fmt.Errorf("invalid cron expression: %w", err) + } + plan.NextRunAt = &nextRun + + return s.planRepo.Update(ctx, plan) +} + +// DeletePlan removes a plan and its results (via CASCADE). +func (s *ScheduledTestService) DeletePlan(ctx context.Context, id int64) error { + return s.planRepo.Delete(ctx, id) +} + +// ListResults returns the most recent results for a plan. 
+func (s *ScheduledTestService) ListResults(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error) { + if limit <= 0 { + limit = 50 + } + return s.resultRepo.ListByPlanID(ctx, planID, limit) +} + +// SaveResult inserts a result and prunes old entries beyond maxResults. +func (s *ScheduledTestService) SaveResult(ctx context.Context, planID int64, maxResults int, result *ScheduledTestResult) error { + result.PlanID = planID + if _, err := s.resultRepo.Create(ctx, result); err != nil { + return err + } + return s.resultRepo.PruneOldResults(ctx, planID, maxResults) +} + +func computeNextRun(cronExpr string, from time.Time) (time.Time, error) { + sched, err := scheduledTestCronParser.Parse(cronExpr) + if err != nil { + return time.Time{}, err + } + return sched.Next(from), nil +} diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 9f8fa14a..4c9540f1 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -605,8 +605,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke var err error if groupID > 0 { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms) - } else { + } else if s.isRunModeSimple() { accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms) } if err != nil { return nil, err @@ -624,7 +626,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke if groupID > 0 { return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform) } - return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) + if s.isRunModeSimple() { + return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) + } + return 
s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, bucket.Platform) } func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index c708e061..6cb13b11 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "log/slog" + "net/url" "strconv" "strings" "sync/atomic" @@ -64,6 +65,19 @@ const minVersionErrorTTL = 5 * time.Second // minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context const minVersionDBTimeout = 5 * time.Second +// cachedBackendMode Backend Mode cache (in-process, 60s TTL) +type cachedBackendMode struct { + value bool + expiresAt int64 // unix nano +} + +var backendModeCache atomic.Value // *cachedBackendMode +var backendModeSF singleflight.Group + +const backendModeCacheTTL = 60 * time.Second +const backendModeErrorTTL = 5 * time.Second +const backendModeDBTimeout = 5 * time.Second + // DefaultSubscriptionGroupReader validates group references used by default subscriptions. 
type DefaultSubscriptionGroupReader interface { GetByID(ctx context.Context, id int64) (*Group, error) @@ -102,11 +116,21 @@ func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, e return s.parseSettings(settings), nil } +// GetFrontendURL 获取前端基础URL(数据库优先,fallback 到配置文件) +func (s *SettingService) GetFrontendURL(ctx context.Context) string { + val, err := s.settingRepo.GetValue(ctx, SettingKeyFrontendURL) + if err == nil && strings.TrimSpace(val) != "" { + return strings.TrimSpace(val) + } + return s.cfg.Server.FrontendURL +} + // GetPublicSettings 获取公开设置(无需登录) func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings, error) { keys := []string{ SettingKeyRegistrationEnabled, SettingKeyEmailVerifyEnabled, + SettingKeyRegistrationEmailSuffixWhitelist, SettingKeyPromoCodeEnabled, SettingKeyPasswordResetEnabled, SettingKeyInvitationCodeEnabled, @@ -124,7 +148,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyPurchaseSubscriptionEnabled, SettingKeyPurchaseSubscriptionURL, SettingKeySoraClientEnabled, + SettingKeyCustomMenuItems, SettingKeyLinuxDoConnectEnabled, + SettingKeyBackendModeEnabled, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -142,28 +168,34 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings // Password reset requires email verification to be enabled emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true" + registrationEmailSuffixWhitelist := ParseRegistrationEmailSuffixWhitelist( + settings[SettingKeyRegistrationEmailSuffixWhitelist], + ) return &PublicSettings{ - RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: emailVerifyEnabled, - PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 - PasswordResetEnabled: passwordResetEnabled, - 
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", - TotpEnabled: settings[SettingKeyTotpEnabled] == "true", - TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], - SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "TianShuAPI"), - SiteLogo: settings[SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - APIBaseURL: settings[SettingKeyAPIBaseURL], - ContactInfo: settings[SettingKeyContactInfo], - DocURL: settings[SettingKeyDocURL], - HomeContent: settings[SettingKeyHomeContent], - HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", - PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", - PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), - SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", - LinuxDoOAuthEnabled: linuxDoEnabled, + RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", + EmailVerifyEnabled: emailVerifyEnabled, + RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist, + PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 + PasswordResetEnabled: passwordResetEnabled, + InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", + TotpEnabled: settings[SettingKeyTotpEnabled] == "true", + TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", + TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], + SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "TianShuAPI"), + SiteLogo: settings[SettingKeySiteLogo], + SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), + APIBaseURL: settings[SettingKeyAPIBaseURL], + ContactInfo: settings[SettingKeyContactInfo], + DocURL: settings[SettingKeyDocURL], + HomeContent: 
settings[SettingKeyHomeContent], + HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", + PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", + PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", + CustomMenuItems: settings[SettingKeyCustomMenuItems], + LinuxDoOAuthEnabled: linuxDoEnabled, + BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", }, nil } @@ -193,65 +225,192 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any // Return a struct that matches the frontend's expected format return &struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo,omitempty"` - SiteSubtitle string `json:"site_subtitle,omitempty"` - APIBaseURL string `json:"api_base_url,omitempty"` - ContactInfo string `json:"contact_info,omitempty"` - DocURL string `json:"doc_url,omitempty"` - HomeContent string `json:"home_content,omitempty"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` - SoraClientEnabled bool `json:"sora_client_enabled"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - Version string `json:"version,omitempty"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + 
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo,omitempty"` + SiteSubtitle string `json:"site_subtitle,omitempty"` + APIBaseURL string `json:"api_base_url,omitempty"` + ContactInfo string `json:"contact_info,omitempty"` + DocURL string `json:"doc_url,omitempty"` + HomeContent string `json:"home_content,omitempty"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` + SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems json.RawMessage `json:"custom_menu_items"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + BackendModeEnabled bool `json:"backend_mode_enabled"` + Version string `json:"version,omitempty"` }{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - PromoCodeEnabled: settings.PromoCodeEnabled, - PasswordResetEnabled: settings.PasswordResetEnabled, - InvitationCodeEnabled: settings.InvitationCodeEnabled, - TotpEnabled: settings.TotpEnabled, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - HomeContent: settings.HomeContent, - HideCcsImportButton: settings.HideCcsImportButton, - PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, - 
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - SoraClientEnabled: settings.SoraClientEnabled, - LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, - Version: s.version, + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: settings.PromoCodeEnabled, + PasswordResetEnabled: settings.PasswordResetEnabled, + InvitationCodeEnabled: settings.InvitationCodeEnabled, + TotpEnabled: settings.TotpEnabled, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, + PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, + PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + SoraClientEnabled: settings.SoraClientEnabled, + CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + BackendModeEnabled: settings.BackendModeEnabled, + Version: s.version, }, nil } +// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON +// array string, returning only items with visibility != "admin". 
+func filterUserVisibleMenuItems(raw string) json.RawMessage { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return json.RawMessage("[]") + } + var items []struct { + Visibility string `json:"visibility"` + } + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return json.RawMessage("[]") + } + + // Parse full items to preserve all fields + var fullItems []json.RawMessage + if err := json.Unmarshal([]byte(raw), &fullItems); err != nil { + return json.RawMessage("[]") + } + + var filtered []json.RawMessage + for i, item := range items { + if item.Visibility != "admin" { + filtered = append(filtered, fullItems[i]) + } + } + if len(filtered) == 0 { + return json.RawMessage("[]") + } + result, err := json.Marshal(filtered) + if err != nil { + return json.RawMessage("[]") + } + return result +} + +// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url +// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. +func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { + settings, err := s.GetPublicSettings(ctx) + if err != nil { + return nil, err + } + + seen := make(map[string]struct{}) + var origins []string + + addOrigin := func(rawURL string) { + if origin := extractOriginFromURL(rawURL); origin != "" { + if _, ok := seen[origin]; !ok { + seen[origin] = struct{}{} + origins = append(origins, origin) + } + } + } + + // purchase subscription URL + if settings.PurchaseSubscriptionEnabled { + addOrigin(settings.PurchaseSubscriptionURL) + } + + // all custom menu items (including admin-only, since CSP must allow all iframes) + for _, item := range parseCustomMenuItemURLs(settings.CustomMenuItems) { + addOrigin(item) + } + + return origins, nil +} + +// extractOriginFromURL returns the scheme+host origin from rawURL. +// Only http and https schemes are accepted. 
+func extractOriginFromURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "" + } + u, err := url.Parse(rawURL) + if err != nil || u.Host == "" { + return "" + } + if u.Scheme != "http" && u.Scheme != "https" { + return "" + } + return u.Scheme + "://" + u.Host +} + +// parseCustomMenuItemURLs extracts URLs from a raw JSON array of custom menu items. +func parseCustomMenuItemURLs(raw string) []string { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return nil + } + var items []struct { + URL string `json:"url"` + } + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil + } + urls := make([]string, 0, len(items)) + for _, item := range items { + if item.URL != "" { + urls = append(urls, item.URL) + } + } + return urls +} + // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { return err } + normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist) + if err != nil { + return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error()) + } + if normalizedWhitelist == nil { + normalizedWhitelist = []string{} + } + settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist updates := make(map[string]string) // 注册设置 updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled) updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) + registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist) + if err != nil { + return fmt.Errorf("marshal registration email suffix whitelist: %w", err) + } + updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON) updates[SettingKeyPromoCodeEnabled] = 
strconv.FormatBool(settings.PromoCodeEnabled) updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled) + updates[SettingKeyFrontendURL] = settings.FrontendURL updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled) updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled) @@ -293,6 +452,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) + updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) @@ -325,6 +485,12 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // Claude Code version check updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion + // 分组隔离 + updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling) + + // Backend Mode + updates[SettingKeyBackendModeEnabled] = strconv.FormatBool(settings.BackendModeEnabled) + err = s.settingRepo.SetMultiple(ctx, updates) if err == nil { // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 @@ -333,6 +499,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet value: settings.MinClaudeCodeVersion, expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(), }) + backendModeSF.Forget("backend_mode") + backendModeCache.Store(&cachedBackendMode{ + value: settings.BackendModeEnabled, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) if s.onUpdate != nil { s.onUpdate() // Invalidate cache after settings update } @@ -389,6 +560,52 @@ func (s *SettingService) IsRegistrationEnabled(ctx 
context.Context) bool { return value == "true" } +// IsBackendModeEnabled checks if backend mode is enabled +// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path +func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool { + if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.value + } + } + result, _, _ := backendModeSF.Do("backend_mode", func() (any, error) { + if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.value, nil + } + } + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), backendModeDBTimeout) + defer cancel() + value, err := s.settingRepo.GetValue(dbCtx, SettingKeyBackendModeEnabled) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + // Setting not yet created (fresh install) - default to disabled with full TTL + backendModeCache.Store(&cachedBackendMode{ + value: false, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) + return false, nil + } + slog.Warn("failed to get backend_mode_enabled setting", "error", err) + backendModeCache.Store(&cachedBackendMode{ + value: false, + expiresAt: time.Now().Add(backendModeErrorTTL).UnixNano(), + }) + return false, nil + } + enabled := value == "true" + backendModeCache.Store(&cachedBackendMode{ + value: enabled, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) + return enabled, nil + }) + if val, ok := result.(bool); ok { + return val + } + return false +} + // IsEmailVerifyEnabled 检查是否开启邮件验证 func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled) @@ -398,6 +615,15 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool { return value == "true" } +// GetRegistrationEmailSuffixWhitelist returns normalized registration email 
suffix whitelist. +func (s *SettingService) GetRegistrationEmailSuffixWhitelist(ctx context.Context) []string { + value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEmailSuffixWhitelist) + if err != nil { + return []string{} + } + return ParseRegistrationEmailSuffixWhitelist(value) +} + // IsPromoCodeEnabled 检查是否启用优惠码功能 func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyPromoCodeEnabled) @@ -501,19 +727,21 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 初始化默认设置 defaults := map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyEmailVerifyEnabled: "false", - SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 - SettingKeySiteName: "TianShuAPI", - SettingKeySiteLogo: "", - SettingKeyPurchaseSubscriptionEnabled: "false", - SettingKeyPurchaseSubscriptionURL: "", - SettingKeySoraClientEnabled: "false", - SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), - SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), - SettingKeyDefaultSubscriptions: "[]", - SettingKeySMTPPort: "587", - SettingKeySMTPUseTLS: "false", + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "false", + SettingKeyRegistrationEmailSuffixWhitelist: "[]", + SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 + SettingKeySiteName: "TianShuAPI", + SettingKeySiteLogo: "", + SettingKeyPurchaseSubscriptionEnabled: "false", + SettingKeyPurchaseSubscriptionURL: "", + SettingKeySoraClientEnabled: "false", + SettingKeyCustomMenuItems: "[]", + SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), + SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultSubscriptions: "[]", + SettingKeySMTPPort: "587", + SettingKeySMTPUseTLS: "false", // Model fallback defaults SettingKeyEnableModelFallback: "false", SettingKeyFallbackModelAnthropic: 
"claude-3-5-sonnet-20241022", @@ -532,6 +760,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // Claude Code version check (default: empty = disabled) SettingKeyMinClaudeCodeVersion: "", + + // 分组隔离(默认不允许未分组 Key 调度) + SettingKeyAllowUngroupedKeyScheduling: "false", } return s.settingRepo.SetMultiple(ctx, defaults) @@ -541,32 +772,36 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings { emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" result := &SystemSettings{ - RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: emailVerifyEnabled, - PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 - PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true", - InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", - TotpEnabled: settings[SettingKeyTotpEnabled] == "true", - SMTPHost: settings[SettingKeySMTPHost], - SMTPUsername: settings[SettingKeySMTPUsername], - SMTPFrom: settings[SettingKeySMTPFrom], - SMTPFromName: settings[SettingKeySMTPFromName], - SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true", - SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "", - TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], - TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "", - SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "TianShuAPI"), - SiteLogo: settings[SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - APIBaseURL: settings[SettingKeyAPIBaseURL], - ContactInfo: settings[SettingKeyContactInfo], - DocURL: settings[SettingKeyDocURL], - HomeContent: settings[SettingKeyHomeContent], - 
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", - PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", - PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), - SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", + RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", + EmailVerifyEnabled: emailVerifyEnabled, + RegistrationEmailSuffixWhitelist: ParseRegistrationEmailSuffixWhitelist(settings[SettingKeyRegistrationEmailSuffixWhitelist]), + PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 + PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true", + FrontendURL: settings[SettingKeyFrontendURL], + InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", + TotpEnabled: settings[SettingKeyTotpEnabled] == "true", + SMTPHost: settings[SettingKeySMTPHost], + SMTPUsername: settings[SettingKeySMTPUsername], + SMTPFrom: settings[SettingKeySMTPFrom], + SMTPFromName: settings[SettingKeySMTPFromName], + SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true", + SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "", + TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", + TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], + TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "", + SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "TianShuAPI"), + SiteLogo: settings[SettingKeySiteLogo], + SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), + APIBaseURL: settings[SettingKeyAPIBaseURL], + ContactInfo: settings[SettingKeyContactInfo], + DocURL: settings[SettingKeyDocURL], + HomeContent: settings[SettingKeyHomeContent], + HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", + PurchaseSubscriptionEnabled: 
settings[SettingKeyPurchaseSubscriptionEnabled] == "true", + PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", + CustomMenuItems: settings[SettingKeyCustomMenuItems], + BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", } // 解析整数类型 @@ -661,6 +896,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin // Claude Code version check result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion] + // 分组隔离 + result.AllowUngroupedKeyScheduling = settings[SettingKeyAllowUngroupedKeyScheduling] == "true" + return result } @@ -983,6 +1221,15 @@ func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamT return &settings, nil } +// IsUngroupedKeySchedulingAllowed 查询是否允许未分组 Key 调度 +func (s *SettingService) IsUngroupedKeySchedulingAllowed(ctx context.Context) bool { + value, err := s.settingRepo.GetValue(ctx, SettingKeyAllowUngroupedKeyScheduling) + if err != nil { + return false // fail-closed: 查询失败时默认不允许 + } + return value == "true" +} + // GetMinClaudeCodeVersion 获取最低 Claude Code 版本号要求 // 使用进程内 atomic.Value 缓存,60 秒 TTL,热路径零锁开销 // singleflight 防止缓存过期时 thundering herd @@ -994,7 +1241,7 @@ func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string { } } // singleflight: 同一时刻只有一个 goroutine 查询 DB,其余复用结果 - result, _, _ := minVersionSF.Do("min_version", func() (any, error) { + result, err, _ := minVersionSF.Do("min_version", func() (any, error) { // 二次检查,避免排队的 goroutine 重复查询 if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok { if time.Now().UnixNano() < cached.expiresAt { @@ -1020,10 +1267,121 @@ func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string { }) return value, nil }) - if s, ok := result.(string); ok { - return s + if err != nil { + return "" } - return "" + ver, ok := result.(string) + if !ok { + return "" + } + return ver +} + +// 
GetRectifierSettings 获取请求整流器配置 +func (s *SettingService) GetRectifierSettings(ctx context.Context) (*RectifierSettings, error) { + value, err := s.settingRepo.GetValue(ctx, SettingKeyRectifierSettings) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + return DefaultRectifierSettings(), nil + } + return nil, fmt.Errorf("get rectifier settings: %w", err) + } + if value == "" { + return DefaultRectifierSettings(), nil + } + + var settings RectifierSettings + if err := json.Unmarshal([]byte(value), &settings); err != nil { + return DefaultRectifierSettings(), nil + } + + return &settings, nil +} + +// SetRectifierSettings 设置请求整流器配置 +func (s *SettingService) SetRectifierSettings(ctx context.Context, settings *RectifierSettings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + data, err := json.Marshal(settings) + if err != nil { + return fmt.Errorf("marshal rectifier settings: %w", err) + } + + return s.settingRepo.Set(ctx, SettingKeyRectifierSettings, string(data)) +} + +// IsSignatureRectifierEnabled 判断签名整流是否启用(总开关 && 签名子开关) +func (s *SettingService) IsSignatureRectifierEnabled(ctx context.Context) bool { + settings, err := s.GetRectifierSettings(ctx) + if err != nil { + return true // fail-open: 查询失败时默认启用 + } + return settings.Enabled && settings.ThinkingSignatureEnabled +} + +// IsBudgetRectifierEnabled 判断 Budget 整流是否启用(总开关 && Budget 子开关) +func (s *SettingService) IsBudgetRectifierEnabled(ctx context.Context) bool { + settings, err := s.GetRectifierSettings(ctx) + if err != nil { + return true // fail-open: 查询失败时默认启用 + } + return settings.Enabled && settings.ThinkingBudgetEnabled +} + +// GetBetaPolicySettings 获取 Beta 策略配置 +func (s *SettingService) GetBetaPolicySettings(ctx context.Context) (*BetaPolicySettings, error) { + value, err := s.settingRepo.GetValue(ctx, SettingKeyBetaPolicySettings) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + return DefaultBetaPolicySettings(), nil + } + return nil, 
fmt.Errorf("get beta policy settings: %w", err) + } + if value == "" { + return DefaultBetaPolicySettings(), nil + } + + var settings BetaPolicySettings + if err := json.Unmarshal([]byte(value), &settings); err != nil { + return DefaultBetaPolicySettings(), nil + } + + return &settings, nil +} + +// SetBetaPolicySettings 设置 Beta 策略配置 +func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *BetaPolicySettings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + validActions := map[string]bool{ + BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true, + } + validScopes := map[string]bool{ + BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true, + } + + for i, rule := range settings.Rules { + if rule.BetaToken == "" { + return fmt.Errorf("rule[%d]: beta_token cannot be empty", i) + } + if !validActions[rule.Action] { + return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action) + } + if !validScopes[rule.Scope] { + return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope) + } + } + + data, err := json.Marshal(settings) + if err != nil { + return fmt.Errorf("marshal beta policy settings: %w", err) + } + + return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data)) } // SetStreamTimeoutSettings 设置流超时处理配置 diff --git a/backend/internal/service/setting_service_backend_mode_test.go b/backend/internal/service/setting_service_backend_mode_test.go new file mode 100644 index 00000000..39922ec8 --- /dev/null +++ b/backend/internal/service/setting_service_backend_mode_test.go @@ -0,0 +1,199 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type bmRepoStub struct { + getValueFn func(ctx context.Context, key string) (string, error) + calls int +} + +func (s *bmRepoStub) Get(ctx 
context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *bmRepoStub) GetValue(ctx context.Context, key string) (string, error) { + s.calls++ + if s.getValueFn == nil { + panic("unexpected GetValue call") + } + return s.getValueFn(ctx, key) +} + +func (s *bmRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *bmRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *bmRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *bmRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *bmRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +type bmUpdateRepoStub struct { + updates map[string]string + getValueFn func(ctx context.Context, key string) (string, error) +} + +func (s *bmUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *bmUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) { + if s.getValueFn == nil { + panic("unexpected GetValue call") + } + return s.getValueFn(ctx, key) +} + +func (s *bmUpdateRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *bmUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *bmUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.updates = make(map[string]string, len(settings)) + for k, v := range settings { + s.updates[k] = v + } + return nil +} + +func (s *bmUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *bmUpdateRepoStub) Delete(ctx 
context.Context, key string) error { + panic("unexpected Delete call") +} + +func resetBackendModeTestCache(t *testing.T) { + t.Helper() + + backendModeCache.Store((*cachedBackendMode)(nil)) + t.Cleanup(func() { + backendModeCache.Store((*cachedBackendMode)(nil)) + }) +} + +func TestIsBackendModeEnabled_ReturnsTrue(t *testing.T) { + resetBackendModeTestCache(t) + + repo := &bmRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "true", nil + }, + } + svc := NewSettingService(repo, &config.Config{}) + + require.True(t, svc.IsBackendModeEnabled(context.Background())) + require.Equal(t, 1, repo.calls) +} + +func TestIsBackendModeEnabled_ReturnsFalse(t *testing.T) { + resetBackendModeTestCache(t) + + repo := &bmRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "false", nil + }, + } + svc := NewSettingService(repo, &config.Config{}) + + require.False(t, svc.IsBackendModeEnabled(context.Background())) + require.Equal(t, 1, repo.calls) +} + +func TestIsBackendModeEnabled_ReturnsFalseOnNotFound(t *testing.T) { + resetBackendModeTestCache(t) + + repo := &bmRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "", ErrSettingNotFound + }, + } + svc := NewSettingService(repo, &config.Config{}) + + require.False(t, svc.IsBackendModeEnabled(context.Background())) + require.Equal(t, 1, repo.calls) +} + +func TestIsBackendModeEnabled_ReturnsFalseOnDBError(t *testing.T) { + resetBackendModeTestCache(t) + + repo := &bmRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "", errors.New("db down") + }, + } + svc := NewSettingService(repo, &config.Config{}) + + require.False(t, svc.IsBackendModeEnabled(context.Background())) + 
require.Equal(t, 1, repo.calls) +} + +func TestIsBackendModeEnabled_CachesResult(t *testing.T) { + resetBackendModeTestCache(t) + + repo := &bmRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "true", nil + }, + } + svc := NewSettingService(repo, &config.Config{}) + + require.True(t, svc.IsBackendModeEnabled(context.Background())) + require.True(t, svc.IsBackendModeEnabled(context.Background())) + require.Equal(t, 1, repo.calls) +} + +func TestUpdateSettings_InvalidatesBackendModeCache(t *testing.T) { + resetBackendModeTestCache(t) + + backendModeCache.Store(&cachedBackendMode{ + value: true, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) + + repo := &bmUpdateRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "true", nil + }, + } + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + BackendModeEnabled: false, + }) + require.NoError(t, err) + require.Equal(t, "false", repo.updates[SettingKeyBackendModeEnabled]) + require.False(t, svc.IsBackendModeEnabled(context.Background())) +} diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go new file mode 100644 index 00000000..b511cd29 --- /dev/null +++ b/backend/internal/service/setting_service_public_test.go @@ -0,0 +1,64 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type settingPublicRepoStub struct { + values map[string]string +} + +func (s *settingPublicRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *settingPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) { + 
panic("unexpected GetValue call") +} + +func (s *settingPublicRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingPublicRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelist(t *testing.T) { + repo := &settingPublicRepoStub{ + values: map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["@EXAMPLE.com"," @foo.bar ","@invalid_domain",""]`, + }, + } + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist) +} diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go index ec64511f..1de08611 100644 --- a/backend/internal/service/setting_service_update_test.go +++ b/backend/internal/service/setting_service_update_test.go @@ -172,6 +172,28 @@ func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGrou require.Nil(t, repo.updates) } +func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Normalized(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := 
NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + RegistrationEmailSuffixWhitelist: []string{"example.com", "@EXAMPLE.com", " @foo.bar "}, + }) + require.NoError(t, err) + require.Equal(t, `["@example.com","@foo.bar"]`, repo.updates[SettingKeyRegistrationEmailSuffixWhitelist]) +} + +func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Invalid(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + RegistrationEmailSuffixWhitelist: []string{"@invalid_domain"}, + }) + require.Error(t, err) + require.Equal(t, "INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", infraerrors.Reason(err)) +} + func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) { got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`) require.Equal(t, []DefaultSubscriptionSetting{ diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 5a441ea1..71c2e7aa 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -1,12 +1,14 @@ package service type SystemSettings struct { - RegistrationEnabled bool - EmailVerifyEnabled bool - PromoCodeEnabled bool - PasswordResetEnabled bool - InvitationCodeEnabled bool - TotpEnabled bool // TOTP 双因素认证 + RegistrationEnabled bool + EmailVerifyEnabled bool + RegistrationEmailSuffixWhitelist []string + PromoCodeEnabled bool + PasswordResetEnabled bool + FrontendURL string + InvitationCodeEnabled bool + TotpEnabled bool // TOTP 双因素认证 SMTPHost string SMTPPort int @@ -40,6 +42,7 @@ type SystemSettings struct { PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string SoraClientEnabled bool + CustomMenuItems string // JSON array of custom menu items 
DefaultConcurrency int DefaultBalance float64 @@ -64,6 +67,12 @@ type SystemSettings struct { // Claude Code version check MinClaudeCodeVersion string + + // 分组隔离:允许未分组 Key 调度(默认 false → 403) + AllowUngroupedKeyScheduling bool + + // Backend 模式:禁用用户注册和自助服务,仅管理员可登录 + BackendModeEnabled bool } type DefaultSubscriptionSetting struct { @@ -72,28 +81,31 @@ type DefaultSubscriptionSetting struct { } type PublicSettings struct { - RegistrationEnabled bool - EmailVerifyEnabled bool - PromoCodeEnabled bool - PasswordResetEnabled bool - InvitationCodeEnabled bool - TotpEnabled bool // TOTP 双因素认证 - TurnstileEnabled bool - TurnstileSiteKey string - SiteName string - SiteLogo string - SiteSubtitle string - APIBaseURL string - ContactInfo string - DocURL string - HomeContent string - HideCcsImportButton bool + RegistrationEnabled bool + EmailVerifyEnabled bool + RegistrationEmailSuffixWhitelist []string + PromoCodeEnabled bool + PasswordResetEnabled bool + InvitationCodeEnabled bool + TotpEnabled bool // TOTP 双因素认证 + TurnstileEnabled bool + TurnstileSiteKey string + SiteName string + SiteLogo string + SiteSubtitle string + APIBaseURL string + ContactInfo string + DocURL string + HomeContent string + HideCcsImportButton bool PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string SoraClientEnabled bool + CustomMenuItems string // JSON array of custom menu items LinuxDoOAuthEnabled bool + BackendModeEnabled bool Version string } @@ -168,3 +180,62 @@ func DefaultStreamTimeoutSettings() *StreamTimeoutSettings { ThresholdWindowMinutes: 10, } } + +// RectifierSettings 请求整流器配置 +type RectifierSettings struct { + Enabled bool `json:"enabled"` // 总开关 + ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` // Thinking 签名整流 + ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` // Thinking Budget 整流 +} + +// DefaultRectifierSettings 返回默认的整流器配置(全部启用) +func DefaultRectifierSettings() *RectifierSettings { + return &RectifierSettings{ + Enabled: true, + 
ThinkingSignatureEnabled: true, + ThinkingBudgetEnabled: true, + } +} + +// Beta Policy 策略常量 +const ( + BetaPolicyActionPass = "pass" // 透传,不做任何处理 + BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token + BetaPolicyActionBlock = "block" // 拦截,直接返回错误 + + BetaPolicyScopeAll = "all" // 所有账号类型 + BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号 + BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号 + BetaPolicyScopeBedrock = "bedrock" // 仅 AWS Bedrock 账号 +) + +// BetaPolicyRule 单条 Beta 策略规则 +type BetaPolicyRule struct { + BetaToken string `json:"beta_token"` // beta token 值 + Action string `json:"action"` // "pass" | "filter" | "block" + Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock" + ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效) +} + +// BetaPolicySettings Beta 策略配置 +type BetaPolicySettings struct { + Rules []BetaPolicyRule `json:"rules"` +} + +// DefaultBetaPolicySettings 返回默认的 Beta 策略配置 +func DefaultBetaPolicySettings() *BetaPolicySettings { + return &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "fast-mode-2026-02-01", + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + }, + { + BetaToken: "context-1m-2025-08-07", + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + }, + }, + } +} diff --git a/backend/internal/service/subscription_calculate_progress_test.go b/backend/internal/service/subscription_calculate_progress_test.go index 22018bcd..53e5c568 100644 --- a/backend/internal/service/subscription_calculate_progress_test.go +++ b/backend/internal/service/subscription_calculate_progress_test.go @@ -34,7 +34,7 @@ func TestCalculateProgress_BasicFields(t *testing.T) { assert.Equal(t, int64(100), progress.ID) assert.Equal(t, "Premium", progress.GroupName) assert.Equal(t, sub.ExpiresAt, progress.ExpiresAt) - assert.Equal(t, 29, progress.ExpiresInDays) // 约 30 天 + assert.True(t, progress.ExpiresInDays == 29 || progress.ExpiresInDays == 30, "ExpiresInDays should be 29 
or 30, got %d", progress.ExpiresInDays) assert.Nil(t, progress.Daily, "无日限额时 Daily 应为 nil") assert.Nil(t, progress.Weekly, "无周限额时 Weekly 应为 nil") assert.Nil(t, progress.Monthly, "无月限额时 Monthly 应为 nil") diff --git a/backend/internal/service/subscription_reset_quota_test.go b/backend/internal/service/subscription_reset_quota_test.go new file mode 100644 index 00000000..3bbc2170 --- /dev/null +++ b/backend/internal/service/subscription_reset_quota_test.go @@ -0,0 +1,207 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// resetQuotaUserSubRepoStub 支持 GetByID、ResetDailyUsage、ResetWeeklyUsage、ResetMonthlyUsage, +// 其余方法继承 userSubRepoNoop(panic)。 +type resetQuotaUserSubRepoStub struct { + userSubRepoNoop + + sub *UserSubscription + + resetDailyCalled bool + resetWeeklyCalled bool + resetMonthlyCalled bool + resetDailyErr error + resetWeeklyErr error + resetMonthlyErr error +} + +func (r *resetQuotaUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) { + if r.sub == nil || r.sub.ID != id { + return nil, ErrSubscriptionNotFound + } + cp := *r.sub + return &cp, nil +} + +func (r *resetQuotaUserSubRepoStub) ResetDailyUsage(_ context.Context, _ int64, windowStart time.Time) error { + r.resetDailyCalled = true + if r.resetDailyErr == nil && r.sub != nil { + r.sub.DailyUsageUSD = 0 + r.sub.DailyWindowStart = &windowStart + } + return r.resetDailyErr +} + +func (r *resetQuotaUserSubRepoStub) ResetWeeklyUsage(_ context.Context, _ int64, _ time.Time) error { + r.resetWeeklyCalled = true + return r.resetWeeklyErr +} + +func (r *resetQuotaUserSubRepoStub) ResetMonthlyUsage(_ context.Context, _ int64, _ time.Time) error { + r.resetMonthlyCalled = true + return r.resetMonthlyErr +} + +func newResetQuotaSvc(stub *resetQuotaUserSubRepoStub) *SubscriptionService { + return NewSubscriptionService(groupRepoNoop{}, stub, nil, nil, nil) +} + +func 
TestAdminResetQuota_ResetBoth(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 1, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + result, err := svc.AdminResetQuota(context.Background(), 1, true, true, false) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage") + require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage") + require.False(t, stub.resetMonthlyCalled, "不应调用 ResetMonthlyUsage") +} + +func TestAdminResetQuota_ResetDailyOnly(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 2, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + result, err := svc.AdminResetQuota(context.Background(), 2, true, false, false) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage") + require.False(t, stub.resetWeeklyCalled, "不应调用 ResetWeeklyUsage") + require.False(t, stub.resetMonthlyCalled, "不应调用 ResetMonthlyUsage") +} + +func TestAdminResetQuota_ResetWeeklyOnly(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 3, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + result, err := svc.AdminResetQuota(context.Background(), 3, false, true, false) + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, stub.resetDailyCalled, "不应调用 ResetDailyUsage") + require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage") + require.False(t, stub.resetMonthlyCalled, "不应调用 ResetMonthlyUsage") +} + +func TestAdminResetQuota_BothFalseReturnsError(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 7, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 7, false, false, false) + + require.ErrorIs(t, err, ErrInvalidInput) + require.False(t, stub.resetDailyCalled) + require.False(t, 
stub.resetWeeklyCalled) + require.False(t, stub.resetMonthlyCalled) +} + +func TestAdminResetQuota_SubscriptionNotFound(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{sub: nil} + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 999, true, true, true) + + require.ErrorIs(t, err, ErrSubscriptionNotFound) + require.False(t, stub.resetDailyCalled) + require.False(t, stub.resetWeeklyCalled) + require.False(t, stub.resetMonthlyCalled) +} + +func TestAdminResetQuota_ResetDailyUsageError(t *testing.T) { + dbErr := errors.New("db error") + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 4, UserID: 10, GroupID: 20}, + resetDailyErr: dbErr, + } + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 4, true, true, false) + + require.ErrorIs(t, err, dbErr) + require.True(t, stub.resetDailyCalled) + require.False(t, stub.resetWeeklyCalled, "daily 失败后不应继续调用 weekly") +} + +func TestAdminResetQuota_ResetWeeklyUsageError(t *testing.T) { + dbErr := errors.New("db error") + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 5, UserID: 10, GroupID: 20}, + resetWeeklyErr: dbErr, + } + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 5, false, true, false) + + require.ErrorIs(t, err, dbErr) + require.True(t, stub.resetWeeklyCalled) +} + +func TestAdminResetQuota_ResetMonthlyOnly(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 8, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + result, err := svc.AdminResetQuota(context.Background(), 8, false, false, true) + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, stub.resetDailyCalled, "不应调用 ResetDailyUsage") + require.False(t, stub.resetWeeklyCalled, "不应调用 ResetWeeklyUsage") + require.True(t, stub.resetMonthlyCalled, "应调用 ResetMonthlyUsage") +} + +func TestAdminResetQuota_ResetMonthlyUsageError(t *testing.T) { + dbErr := 
errors.New("db error") + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 9, UserID: 10, GroupID: 20}, + resetMonthlyErr: dbErr, + } + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 9, false, false, true) + + require.ErrorIs(t, err, dbErr) + require.True(t, stub.resetMonthlyCalled) +} + +func TestAdminResetQuota_ReturnsRefreshedSub(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ + ID: 6, + UserID: 10, + GroupID: 20, + DailyUsageUSD: 99.9, + }, + } + + svc := newResetQuotaSvc(stub) + result, err := svc.AdminResetQuota(context.Background(), 6, true, false, false) + + require.NoError(t, err) + // ResetDailyUsage stub 会将 sub.DailyUsageUSD 归零, + // 服务应返回第二次 GetByID 的刷新值而非初始的 99.9 + require.Equal(t, float64(0), result.DailyUsageUSD, "返回的订阅应反映已归零的用量") + require.True(t, stub.resetDailyCalled) +} diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 57e04266..af548509 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -31,6 +31,7 @@ var ( ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group") ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics") ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type") + ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily, resetWeekly, or resetMonthly must be true") ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded") ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") ErrMonthlyLimitExceeded = 
infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") @@ -695,6 +696,46 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *U return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart) } +// AdminResetQuota manually resets the daily, weekly, and/or monthly usage windows. +// Uses startOfDay(now) as the new window start, matching automatic resets. +func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly, resetMonthly bool) (*UserSubscription, error) { + if !resetDaily && !resetWeekly && !resetMonthly { + return nil, ErrInvalidInput + } + sub, err := s.userSubRepo.GetByID(ctx, subscriptionID) + if err != nil { + return nil, err + } + windowStart := startOfDay(time.Now()) + if resetDaily { + if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil { + return nil, err + } + } + if resetWeekly { + if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil { + return nil, err + } + } + if resetMonthly { + if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, windowStart); err != nil { + return nil, err + } + } + // Invalidate L1 ristretto cache. Ristretto's Del() is asynchronous by design, + // so call Wait() immediately after to flush pending operations and guarantee + // the deleted key is not returned on the very next Get() call. 
+ s.InvalidateSubCache(sub.UserID, sub.GroupID) + if s.subCacheL1 != nil { + s.subCacheL1.Wait() + } + if s.billingCacheService != nil { + _ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID) + } + // Return the refreshed subscription from DB + return s.userSubRepo.GetByID(ctx, subscriptionID) +} + // CheckAndResetWindows 检查并重置过期的窗口 func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error { // 使用当天零点作为新窗口起始时间 diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index a37e0d0a..cb8841b0 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "log/slog" "strings" @@ -16,9 +17,17 @@ import ( type TokenRefreshService struct { accountRepo AccountRepository refreshers []TokenRefresher + executors []OAuthRefreshExecutor // 与 refreshers 一一对应的 executor(带 CacheKey) + refreshPolicy BackgroundRefreshPolicy cfg *config.TokenRefreshConfig cacheInvalidator TokenCacheInvalidator - schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题 + schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题 + tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存 + refreshAPI *OAuthRefreshAPI // 统一刷新 API + + // OpenAI privacy: 刷新成功后检查并设置 training opt-out + privacyClientFactory PrivacyClientFactory + proxyRepo ProxyRepository stopCh chan struct{} wg sync.WaitGroup @@ -34,24 +43,39 @@ func NewTokenRefreshService( cacheInvalidator TokenCacheInvalidator, schedulerCache SchedulerCache, cfg *config.Config, + tempUnschedCache TempUnschedCache, ) *TokenRefreshService { s := &TokenRefreshService{ accountRepo: accountRepo, + refreshPolicy: DefaultBackgroundRefreshPolicy(), cfg: &cfg.TokenRefresh, cacheInvalidator: cacheInvalidator, schedulerCache: schedulerCache, + tempUnschedCache: tempUnschedCache, stopCh: make(chan 
struct{}), } openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo) openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts) - // 注册平台特定的刷新器 + claudeRefresher := NewClaudeTokenRefresher(oauthService) + geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService) + agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService) + + // 注册平台特定的刷新器(TokenRefresher 接口) s.refreshers = []TokenRefresher{ - NewClaudeTokenRefresher(oauthService), + claudeRefresher, openAIRefresher, - NewGeminiTokenRefresher(geminiOAuthService), - NewAntigravityTokenRefresher(antigravityOAuthService), + geminiRefresher, + agRefresher, + } + + // 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法) + s.executors = []OAuthRefreshExecutor{ + claudeRefresher, + openAIRefresher, + geminiRefresher, + agRefresher, } return s @@ -69,6 +93,22 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) { } } +// SetPrivacyDeps 注入 OpenAI privacy opt-out 所需依赖 +func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxyRepo ProxyRepository) { + s.privacyClientFactory = factory + s.proxyRepo = proxyRepo +} + +// SetRefreshAPI 注入统一的 OAuth 刷新 API +func (s *TokenRefreshService) SetRefreshAPI(api *OAuthRefreshAPI) { + s.refreshAPI = api +} + +// SetRefreshPolicy 注入后台刷新调用侧策略(用于显式化平台/场景差异行为)。 +func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) { + s.refreshPolicy = policy +} + // Start 启动后台刷新服务 func (s *TokenRefreshService) Start() { if !s.cfg.Enabled { @@ -135,13 +175,13 @@ func (s *TokenRefreshService) processRefresh() { totalAccounts := len(accounts) oauthAccounts := 0 // 可刷新的OAuth账号数 needsRefresh := 0 // 需要刷新的账号数 - refreshed, failed := 0, 0 + refreshed, failed, skipped := 0, 0, 0 for i := range accounts { account := &accounts[i] // 遍历所有刷新器,找到能处理此账号的 - for _, refresher := range s.refreshers { + for idx, refresher := range s.refreshers { if !refresher.CanRefresh(account) { continue } @@ -155,14 
+195,24 @@ func (s *TokenRefreshService) processRefresh() { needsRefresh++ + // 获取对应的 executor + var executor OAuthRefreshExecutor + if idx < len(s.executors) { + executor = s.executors[idx] + } + // 执行刷新 - if err := s.refreshWithRetry(ctx, account, refresher); err != nil { - slog.Warn("token_refresh.account_refresh_failed", - "account_id", account.ID, - "account_name", account.Name, - "error", err, - ) - failed++ + if err := s.refreshWithRetry(ctx, account, refresher, executor, refreshWindow); err != nil { + if errors.Is(err, errRefreshSkipped) { + skipped++ + } else { + slog.Warn("token_refresh.account_refresh_failed", + "account_id", account.ID, + "account_name", account.Name, + "error", err, + ) + failed++ + } } else { slog.Info("token_refresh.account_refreshed", "account_id", account.ID, @@ -180,13 +230,14 @@ func (s *TokenRefreshService) processRefresh() { if needsRefresh == 0 && failed == 0 { slog.Debug("token_refresh.cycle_completed", "total", totalAccounts, "oauth", oauthAccounts, - "needs_refresh", needsRefresh, "refreshed", refreshed, "failed", failed) + "needs_refresh", needsRefresh, "refreshed", refreshed, "skipped", skipped, "failed", failed) } else { slog.Info("token_refresh.cycle_completed", "total", totalAccounts, "oauth", oauthAccounts, "needs_refresh", needsRefresh, "refreshed", refreshed, + "skipped", skipped, "failed", failed, ) } @@ -199,66 +250,47 @@ func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account } // refreshWithRetry 带重试的刷新 -func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher) error { +func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher, executor OAuthRefreshExecutor, refreshWindow time.Duration) error { var lastErr error for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ { - newCredentials, err := refresher.Refresh(ctx, account) + var newCredentials map[string]any + var err error - // 
如果有新凭证,先更新(即使有错误也要保存 token) - if newCredentials != nil { - // 记录刷新版本时间戳,用于解决缓存一致性问题 - // TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入 - newCredentials["_token_version"] = time.Now().UnixMilli() - - account.Credentials = newCredentials - if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil { - return fmt.Errorf("failed to save credentials: %w", saveErr) + // 优先使用统一 API(带分布式锁 + DB 重读保护) + if s.refreshAPI != nil && executor != nil { + result, refreshErr := s.refreshAPI.RefreshIfNeeded(ctx, account, executor, refreshWindow) + if refreshErr != nil { + err = refreshErr + } else if result.LockHeld { + // 锁被其他 worker 持有,由调用侧策略决定如何计数 + return s.refreshPolicy.handleLockHeld() + } else if !result.Refreshed { + // 已被其他路径刷新,由调用侧策略决定如何计数 + return s.refreshPolicy.handleAlreadyRefreshed() + } else { + account = result.Account + _ = result.NewCredentials // 统一 API 已设置 _token_version 并更新 DB,无需重复操作 + } + } else { + // 降级:直接调用 refresher(兼容旧路径) + newCredentials, err = refresher.Refresh(ctx, account) + if newCredentials != nil { + newCredentials["_token_version"] = time.Now().UnixMilli() + account.Credentials = newCredentials + if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil { + return fmt.Errorf("failed to save credentials: %w", saveErr) + } } } if err == nil { - // Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态 - if account.Platform == PlatformAntigravity && - account.Status == StatusError && - strings.Contains(account.ErrorMessage, "missing_project_id:") { - if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil { - slog.Warn("token_refresh.clear_account_error_failed", - "account_id", account.ID, - "error", clearErr, - ) - } else { - slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID) - } - } - // 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理) - if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth { - if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil { - 
slog.Warn("token_refresh.invalidate_token_cache_failed", - "account_id", account.ID, - "error", err, - ) - } else { - slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID) - } - } - // 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials - // 这解决了 token 刷新后调度器缓存数据不一致的问题(#445) - if s.schedulerCache != nil { - if err := s.schedulerCache.SetAccount(ctx, account); err != nil { - slog.Warn("token_refresh.sync_scheduler_cache_failed", - "account_id", account.ID, - "error", err, - ) - } else { - slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID) - } - } + s.postRefreshActions(ctx, account) return nil } - // Antigravity 账户:不可重试错误直接标记 error 状态并返回 - if account.Platform == PlatformAntigravity && isNonRetryableRefreshError(err) { + // 不可重试错误(invalid_grant/invalid_client 等)直接标记 error 状态并返回 + if isNonRetryableRefreshError(err) { errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err) if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil { slog.Error("token_refresh.set_error_status_failed", @@ -285,27 +317,81 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc } } - // Antigravity 账户:其他错误仅记录日志,不标记 error(可能是临时网络问题) - // 其他平台账户:重试失败后标记 error - if account.Platform == PlatformAntigravity { - slog.Warn("token_refresh.retry_exhausted_antigravity", - "account_id", account.ID, - "max_retries", s.cfg.MaxRetries, - "error", lastErr, - ) - } else { - errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr) - if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { - slog.Error("token_refresh.set_error_status_failed", - "account_id", account.ID, - "error", err, - ) - } - } + // 可重试错误耗尽:仅记录日志,不标记 error(可能是临时网络问题,下个周期继续重试) + slog.Warn("token_refresh.retry_exhausted", + "account_id", account.ID, + "platform", account.Platform, + "max_retries", s.cfg.MaxRetries, + "error", lastErr, + ) return lastErr } +// postRefreshActions 
刷新成功后的后续动作(清除错误状态、缓存失效、调度器同步等) +func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *Account) { + // Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态 + if account.Platform == PlatformAntigravity && + account.Status == StatusError && + strings.Contains(account.ErrorMessage, "missing_project_id:") { + if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil { + slog.Warn("token_refresh.clear_account_error_failed", + "account_id", account.ID, + "error", clearErr, + ) + } else { + slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID) + } + } + // 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景) + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil { + slog.Warn("token_refresh.clear_temp_unschedulable_failed", + "account_id", account.ID, + "error", clearErr, + ) + } else { + slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID) + } + // 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态 + if s.tempUnschedCache != nil { + if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil { + slog.Warn("token_refresh.clear_temp_unsched_cache_failed", + "account_id", account.ID, + "error", clearErr, + ) + } + } + } + // 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理) + if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth { + if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil { + slog.Warn("token_refresh.invalidate_token_cache_failed", + "account_id", account.ID, + "error", err, + ) + } else { + slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID) + } + } + // 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials + if s.schedulerCache != nil { + if err := s.schedulerCache.SetAccount(ctx, account); err != nil { + slog.Warn("token_refresh.sync_scheduler_cache_failed", + "account_id", 
account.ID, + "error", err, + ) + } else { + slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID) + } + } + // OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享 + s.ensureOpenAIPrivacy(ctx, account) +} + +// errRefreshSkipped 表示刷新被跳过(锁竞争或已被其他路径刷新),不计入 failed 或 refreshed +var errRefreshSkipped = fmt.Errorf("refresh skipped") + // isNonRetryableRefreshError 判断是否为不可重试的刷新错误 // 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权 // 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误 @@ -328,3 +414,49 @@ func isNonRetryableRefreshError(err error) bool { } return false } + +// ensureOpenAIPrivacy 检查 OpenAI OAuth 账号是否已设置 privacy_mode, +// 未设置则调用 disableOpenAITraining 并持久化结果到 Extra。 +func (s *TokenRefreshService) ensureOpenAIPrivacy(ctx context.Context, account *Account) { + if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { + return + } + if s.privacyClientFactory == nil { + return + } + // 已设置过则跳过 + if account.Extra != nil { + if _, ok := account.Extra["privacy_mode"]; ok { + return + } + } + + token, _ := account.Credentials["access_token"].(string) + if token == "" { + return + } + + var proxyURL string + if account.ProxyID != nil && s.proxyRepo != nil { + if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil { + proxyURL = p.URL() + } + } + + mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL) + if mode == "" { + return + } + + if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}); err != nil { + slog.Warn("token_refresh.update_privacy_mode_failed", + "account_id", account.ID, + "error", err, + ) + } else { + slog.Info("token_refresh.privacy_mode_set", + "account_id", account.ID, + "privacy_mode", mode, + ) + } +} diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index 8e16c6f5..f48de65e 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ 
b/backend/internal/service/token_refresh_service_test.go @@ -14,10 +14,11 @@ import ( type tokenRefreshAccountRepo struct { mockAccountRepoForGemini - updateCalls int - setErrorCalls int - lastAccount *Account - updateErr error + updateCalls int + setErrorCalls int + clearTempCalls int + lastAccount *Account + updateErr error } func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { @@ -31,6 +32,11 @@ func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorM return nil } +func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { + r.clearTempCalls++ + return nil +} + type tokenCacheInvalidatorStub struct { calls int err error @@ -41,6 +47,23 @@ func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account return s.err } +type tempUnschedCacheStub struct { + deleteCalls int +} + +func (s *tempUnschedCacheStub) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error { + return nil +} + +func (s *tempUnschedCacheStub) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) { + return nil, nil +} + +func (s *tempUnschedCacheStub) DeleteTempUnsched(ctx context.Context, accountID int64) error { + s.deleteCalls++ + return nil +} + type tokenRefresherStub struct { credentials map[string]any err error @@ -61,6 +84,10 @@ func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map return r.credentials, nil } +func (r *tokenRefresherStub) CacheKey(account *Account) string { + return "test:stub:" + account.Platform +} + func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { repo := &tokenRefreshAccountRepo{} invalidator := &tokenCacheInvalidatorStub{} @@ -70,7 +97,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := 
NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 5, Platform: PlatformGemini, @@ -82,7 +109,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, invalidator.calls) @@ -98,7 +125,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 6, Platform: PlatformGemini, @@ -110,7 +137,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, invalidator.calls) @@ -124,7 +151,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil) account := &Account{ ID: 7, Platform: PlatformGemini, @@ -136,7 +163,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) { }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) } @@ -151,7 +178,7 @@ func 
TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 8, Platform: PlatformAntigravity, @@ -163,7 +190,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) { }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效 @@ -179,7 +206,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 9, Platform: PlatformGemini, @@ -191,7 +218,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) { }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效 @@ -207,7 +234,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 10, Platform: PlatformOpenAI, // OpenAI OAuth 账户 @@ -219,7 +246,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { }, } - err := 
service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效 @@ -235,7 +262,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 11, Platform: PlatformGemini, @@ -247,14 +274,14 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.Error(t, err) require.Contains(t, err.Error(), "failed to save credentials") require.Equal(t, 1, repo.updateCalls) require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效 } -// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况 +// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试可重试错误耗尽不标记 error func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) { repo := &tokenRefreshAccountRepo{} invalidator := &tokenCacheInvalidatorStub{} @@ -264,7 +291,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) { RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 12, Platform: PlatformGemini, @@ -274,11 +301,11 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) { err: errors.New("refresh failed"), } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := 
service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.Error(t, err) require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新 require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效 - require.Equal(t, 1, repo.setErrorCalls) // 应设置错误状态 + require.Equal(t, 0, repo.setErrorCalls) // 可重试错误耗尽不标记 error,下个周期继续重试 } // TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态 @@ -291,7 +318,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 13, Platform: PlatformAntigravity, @@ -301,7 +328,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin err: errors.New("network error"), // 可重试错误 } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.Error(t, err) require.Equal(t, 0, repo.updateCalls) require.Equal(t, 0, invalidator.calls) @@ -318,7 +345,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te RetryBackoffSeconds: 0, }, } - service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg) + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) account := &Account{ ID: 14, Platform: PlatformAntigravity, @@ -328,13 +355,84 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te err: errors.New("invalid_grant: token revoked"), // 不可重试错误 } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.Error(t, err) require.Equal(t, 0, repo.updateCalls) 
require.Equal(t, 0, invalidator.calls) require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态 } +// TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable 测试刷新成功后清除临时不可调度(DB + Redis) +func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + tempCache := &tempUnschedCacheStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache) + until := time.Now().Add(10 * time.Minute) + account := &Account{ + ID: 15, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + TempUnschedulableUntil: &until, + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "new-token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 1, repo.clearTempCalls) // DB 清除 + require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除 +} + +// TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms 测试所有平台不可重试错误都 SetError +func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *testing.T) { + tests := []struct { + name string + platform string + }{ + {name: "gemini", platform: PlatformGemini}, + {name: "anthropic", platform: PlatformAnthropic}, + {name: "openai", platform: PlatformOpenAI}, + {name: "antigravity", platform: PlatformAntigravity}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 3, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + 
account := &Account{ + ID: 16, + Platform: tt.platform, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + err: errors.New("invalid_grant: token revoked"), + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.Error(t, err) + require.Equal(t, 1, repo.setErrorCalls) // 所有平台不可重试错误都应 SetError + }) + } +} + // TestIsNonRetryableRefreshError 测试不可重试错误判断 func TestIsNonRetryableRefreshError(t *testing.T) { tests := []struct { @@ -359,3 +457,212 @@ func TestIsNonRetryableRefreshError(t *testing.T) { }) } } + +// ========== Path A (refreshAPI) 测试用例 ========== + +// mockTokenCacheForRefreshAPI 用于 Path A 测试的 GeminiTokenCache mock +type mockTokenCacheForRefreshAPI struct { + lockResult bool + lockErr error + releaseCalls int +} + +func (m *mockTokenCacheForRefreshAPI) GetAccessToken(_ context.Context, _ string) (string, error) { + return "", errors.New("not cached") +} + +func (m *mockTokenCacheForRefreshAPI) SetAccessToken(_ context.Context, _ string, _ string, _ time.Duration) error { + return nil +} + +func (m *mockTokenCacheForRefreshAPI) DeleteAccessToken(_ context.Context, _ string) error { + return nil +} + +func (m *mockTokenCacheForRefreshAPI) AcquireRefreshLock(_ context.Context, _ string, _ time.Duration) (bool, error) { + return m.lockResult, m.lockErr +} + +func (m *mockTokenCacheForRefreshAPI) ReleaseRefreshLock(_ context.Context, _ string) error { + m.releaseCalls++ + return nil +} + +// buildPathAService 构建注入了 refreshAPI 的 service(Path A 测试辅助) +func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, invalidator TokenCacheInvalidator) (*TokenRefreshService, *tokenRefresherStub) { + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + refreshAPI := NewOAuthRefreshAPI(repo, cache) + 
service.SetRefreshAPI(refreshAPI) + + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "refreshed-token", + }, + } + return service, refresher +} + +// TestPathA_Success 统一 API 路径正常成功:刷新 + DB 更新 + postRefreshActions +func TestPathA_Success(t *testing.T) { + account := &Account{ + ID: 100, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{} + repo.accountsByID = map[int64]*Account{account.ID: account} + invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: true} + + service, refresher := buildPathAService(repo, cache, invalidator) + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCalls) // DB 更新被调用 + require.Equal(t, 1, invalidator.calls) // 缓存失效被调用 + require.Equal(t, 1, cache.releaseCalls) // 锁被释放 +} + +// TestPathA_LockHeld 锁被其他 worker 持有 → 返回 errRefreshSkipped +func TestPathA_LockHeld(t *testing.T) { + account := &Account{ + ID: 101, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: false} // 锁获取失败(被占) + + service, refresher := buildPathAService(repo, cache, invalidator) + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.ErrorIs(t, err, errRefreshSkipped) + require.Equal(t, 0, repo.updateCalls) // 不应更新 DB + require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效 +} + +// TestPathA_AlreadyRefreshed 二次检查发现已被其他路径刷新 → 返回 errRefreshSkipped +func TestPathA_AlreadyRefreshed(t *testing.T) { + // NeedsRefresh 返回 false → RefreshIfNeeded 返回 {Refreshed: false} + account := &Account{ + ID: 102, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{} + repo.accountsByID = map[int64]*Account{account.ID: account} + 
invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: true} + + service, _ := buildPathAService(repo, cache, invalidator) + + // 使用一个 NeedsRefresh 返回 false 的 stub + noRefreshNeeded := &tokenRefresherStub{ + credentials: map[string]any{"access_token": "token"}, + } + // 覆盖 NeedsRefresh 行为 — 我们需要一个新的 stub 类型 + alwaysFreshStub := &alwaysFreshRefresherStub{} + + err := service.refreshWithRetry(context.Background(), account, noRefreshNeeded, alwaysFreshStub, time.Hour) + require.ErrorIs(t, err, errRefreshSkipped) + require.Equal(t, 0, repo.updateCalls) + require.Equal(t, 0, invalidator.calls) +} + +// alwaysFreshRefresherStub 二次检查时认为不需要刷新(模拟已被其他路径刷新) +type alwaysFreshRefresherStub struct{} + +func (r *alwaysFreshRefresherStub) CanRefresh(_ *Account) bool { return true } +func (r *alwaysFreshRefresherStub) NeedsRefresh(_ *Account, _ time.Duration) bool { return false } +func (r *alwaysFreshRefresherStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) { + return nil, errors.New("should not be called") +} +func (r *alwaysFreshRefresherStub) CacheKey(account *Account) string { + return "test:fresh:" + account.Platform +} + +// TestPathA_NonRetryableError 统一 API 路径返回不可重试错误 → SetError +func TestPathA_NonRetryableError(t *testing.T) { + account := &Account{ + ID: 103, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{} + repo.accountsByID = map[int64]*Account{account.ID: account} + invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: true} + + service, _ := buildPathAService(repo, cache, invalidator) + + refresher := &tokenRefresherStub{ + err: errors.New("invalid_grant: token revoked"), + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.Error(t, err) + require.Equal(t, 1, repo.setErrorCalls) // 应标记 error 状态 + require.Equal(t, 0, repo.updateCalls) // 不应更新 credentials + 
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效 +} + +// TestPathA_RetryableErrorExhausted 统一 API 路径可重试错误耗尽 → 不标记 error +func TestPathA_RetryableErrorExhausted(t *testing.T) { + account := &Account{ + ID: 104, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{} + repo.accountsByID = map[int64]*Account{account.ID: account} + invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: true} + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 2, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + refreshAPI := NewOAuthRefreshAPI(repo, cache) + service.SetRefreshAPI(refreshAPI) + + refresher := &tokenRefresherStub{ + err: errors.New("network timeout"), + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.Error(t, err) + require.Equal(t, 0, repo.setErrorCalls) // 可重试错误不标记 error + require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新 + require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效 +} + +// TestPathA_DBUpdateFailed 统一 API 路径 DB 更新失败 → 返回 error,不执行 postRefreshActions +func TestPathA_DBUpdateFailed(t *testing.T) { + account := &Account{ + ID: 105, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{updateErr: errors.New("db connection lost")} + repo.accountsByID = map[int64]*Account{account.ID: account} + invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: true} + + service, refresher := buildPathAService(repo, cache, invalidator) + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.Error(t, err) + require.Contains(t, err.Error(), "DB update failed") + require.Equal(t, 1, repo.updateCalls) // DB 更新被尝试 + require.Equal(t, 0, invalidator.calls) // DB 失败时不应触发缓存失效 +} diff --git 
a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 0dd3cf45..5a214161 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -3,7 +3,6 @@ package service import ( "context" "log" - "strconv" "time" ) @@ -33,6 +32,11 @@ func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher { } } +// CacheKey 返回用于分布式锁的缓存键 +func (r *ClaudeTokenRefresher) CacheKey(account *Account) string { + return ClaudeTokenCacheKey(account) +} + // CanRefresh 检查是否能处理此账号 // 只处理 anthropic 平台的 oauth 类型账号 // setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新 @@ -59,24 +63,8 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m return nil, err } - // 保留现有credentials中的所有字段 - newCredentials := make(map[string]any) - for k, v := range account.Credentials { - newCredentials[k] = v - } - - // 只更新token相关字段 - // 注意:expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型 - newCredentials["access_token"] = tokenInfo.AccessToken - newCredentials["token_type"] = tokenInfo.TokenType - newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10) - newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) - if tokenInfo.RefreshToken != "" { - newCredentials["refresh_token"] = tokenInfo.RefreshToken - } - if tokenInfo.Scope != "" { - newCredentials["scope"] = tokenInfo.Scope - } + newCredentials := BuildClaudeAccountCredentials(tokenInfo) + newCredentials = MergeCredentials(account.Credentials, newCredentials) return newCredentials, nil } @@ -97,6 +85,11 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService, accountRepo } } +// CacheKey 返回用于分布式锁的缓存键 +func (r *OpenAITokenRefresher) CacheKey(account *Account) string { + return OpenAITokenCacheKey(account) +} + // SetSoraAccountRepo 设置 Sora 账号扩展表仓储 // 用于在 Token 刷新时同步更新 sora_accounts 表 // 如果未设置,syncLinkedSoraAccounts 只会更新 accounts.credentials @@ -137,13 +130,7 @@ func (r 
*OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m // 使用服务提供的方法构建新凭证,并保留原有字段 newCredentials := r.openaiOAuthService.BuildAccountCredentials(tokenInfo) - - // 保留原有credentials中非token相关字段 - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } + newCredentials = MergeCredentials(account.Credentials, newCredentials) // 异步同步关联的 Sora 账号(不阻塞主流程) if r.accountRepo != nil && r.syncLinkedSora { diff --git a/backend/internal/service/usage_billing.go b/backend/internal/service/usage_billing.go new file mode 100644 index 00000000..73b05743 --- /dev/null +++ b/backend/internal/service/usage_billing.go @@ -0,0 +1,110 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "strings" +) + +var ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required") +var ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict") + +// UsageBillingCommand describes one billable request that must be applied at most once. 
+type UsageBillingCommand struct { + RequestID string + APIKeyID int64 + RequestFingerprint string + RequestPayloadHash string + + UserID int64 + AccountID int64 + SubscriptionID *int64 + AccountType string + Model string + ServiceTier string + ReasoningEffort string + BillingType int8 + InputTokens int + OutputTokens int + CacheCreationTokens int + CacheReadTokens int + ImageCount int + MediaType string + + BalanceCost float64 + SubscriptionCost float64 + APIKeyQuotaCost float64 + APIKeyRateLimitCost float64 + AccountQuotaCost float64 +} + +func (c *UsageBillingCommand) Normalize() { + if c == nil { + return + } + c.RequestID = strings.TrimSpace(c.RequestID) + if strings.TrimSpace(c.RequestFingerprint) == "" { + c.RequestFingerprint = buildUsageBillingFingerprint(c) + } +} + +func buildUsageBillingFingerprint(c *UsageBillingCommand) string { + if c == nil { + return "" + } + raw := fmt.Sprintf( + "%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f", + c.UserID, + c.AccountID, + c.APIKeyID, + strings.TrimSpace(c.AccountType), + strings.TrimSpace(c.Model), + strings.TrimSpace(c.ServiceTier), + strings.TrimSpace(c.ReasoningEffort), + c.BillingType, + c.InputTokens, + c.OutputTokens, + c.CacheCreationTokens, + c.CacheReadTokens, + c.ImageCount, + strings.TrimSpace(c.MediaType), + valueOrZero(c.SubscriptionID), + c.BalanceCost, + c.SubscriptionCost, + c.APIKeyQuotaCost, + c.APIKeyRateLimitCost, + c.AccountQuotaCost, + ) + if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" { + raw += "|" + payloadHash + } + sum := sha256.Sum256([]byte(raw)) + return hex.EncodeToString(sum[:]) +} + +func HashUsageRequestPayload(payload []byte) string { + if len(payload) == 0 { + return "" + } + sum := sha256.Sum256(payload) + return hex.EncodeToString(sum[:]) +} + +func valueOrZero(v *int64) int64 { + if v == nil { + return 0 + } + return *v +} + +type UsageBillingApplyResult struct { + Applied bool + APIKeyQuotaExhausted bool +} + 
+type UsageBillingRepository interface { + Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) +} diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go index 0fdbfd47..17f21bef 100644 --- a/backend/internal/service/usage_cleanup_service_test.go +++ b/backend/internal/service/usage_cleanup_service_test.go @@ -56,7 +56,8 @@ type cleanupRepoStub struct { } type dashboardRepoStub struct { - recomputeErr error + recomputeErr error + recomputeCalls int } func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error { @@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time. } func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + s.recomputeCalls++ return s.recomputeErr } @@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti return nil } +func (s *dashboardRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + return nil +} + func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { return nil } @@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) { } func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) { + dashboardRepo := &dashboardRepoStub{recomputeErr: errors.New("recompute failed")} repo := &cleanupRepoStub{ deleteQueue: []cleanupDeleteResponse{ {deleted: 0}, }, } - dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ - DashboardAgg: config.DashboardAggregationConfig{Enabled: false}, + dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{ + DashboardAgg: config.DashboardAggregationConfig{Enabled: true}, }) cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} svc := 
NewUsageCleanupService(repo, nil, dashboard, cfg) @@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) { repo.mu.Lock() defer repo.mu.Unlock() require.Len(t, repo.markSucceeded, 1) + require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond) } func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) { + dashboardRepo := &dashboardRepoStub{} repo := &cleanupRepoStub{ deleteQueue: []cleanupDeleteResponse{ {deleted: 0}, }, } - dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ + dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{ DashboardAgg: config.DashboardAggregationConfig{Enabled: true}, }) cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} @@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) { repo.mu.Lock() defer repo.mu.Unlock() require.Len(t, repo.markSucceeded, 1) + require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond) } func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) { diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index c1a95541..7f1bef7f 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -98,9 +98,16 @@ type UsageLog struct { AccountID int64 RequestID string Model string - // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API), - // e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable. + // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". + ServiceTier *string + // ReasoningEffort is the request's reasoning effort level. + // OpenAI: "low" / "medium" / "high" / "xhigh"; Claude: "low" / "medium" / "high" / "max". 
+ // Nil means not provided / not applicable. ReasoningEffort *string + // InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions. + InboundEndpoint *string + // UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses. + UpstreamEndpoint *string GroupID *int64 SubscriptionID *int64 diff --git a/backend/internal/service/usage_log_create_result.go b/backend/internal/service/usage_log_create_result.go new file mode 100644 index 00000000..1cd84f44 --- /dev/null +++ b/backend/internal/service/usage_log_create_result.go @@ -0,0 +1,82 @@ +package service + +import "errors" + +type usageLogCreateDisposition int + +const ( + usageLogCreateDispositionUnknown usageLogCreateDisposition = iota + usageLogCreateDispositionNotPersisted + usageLogCreateDispositionDropped +) + +type UsageLogCreateError struct { + err error + disposition usageLogCreateDisposition +} + +func (e *UsageLogCreateError) Error() string { + if e == nil || e.err == nil { + return "usage log create error" + } + return e.err.Error() +} + +func (e *UsageLogCreateError) Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +func MarkUsageLogCreateNotPersisted(err error) error { + if err == nil { + return nil + } + return &UsageLogCreateError{ + err: err, + disposition: usageLogCreateDispositionNotPersisted, + } +} + +func MarkUsageLogCreateDropped(err error) error { + if err == nil { + return nil + } + return &UsageLogCreateError{ + err: err, + disposition: usageLogCreateDispositionDropped, + } +} + +func IsUsageLogCreateNotPersisted(err error) bool { + if err == nil { + return false + } + var target *UsageLogCreateError + if !errors.As(err, &target) { + return false + } + return target.disposition == usageLogCreateDispositionNotPersisted +} + +func IsUsageLogCreateDropped(err error) bool { + if err == nil { + return false + } + var target *UsageLogCreateError + if !errors.As(err, &target) { + return false + } + return target.disposition == 
usageLogCreateDispositionDropped +} + +func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool { + if inserted { + return true + } + if err == nil { + return false + } + return !IsUsageLogCreateNotPersisted(err) +} diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index f21a2855..d64f01e0 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -315,6 +315,15 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star return stats, nil } +// GetAPIKeyModelStats returns per-model usage stats for a specific API Key. +func (s *UsageService) GetAPIKeyModelStats(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, 0, apiKeyID, 0, 0, nil, nil, nil) + if err != nil { + return nil, fmt.Errorf("get api key model stats: %w", err) + } + return stats, nil +} + // GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys. 
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime) diff --git a/backend/internal/service/user_group_rate.go b/backend/internal/service/user_group_rate.go index 9eb5f067..3d221a25 100644 --- a/backend/internal/service/user_group_rate.go +++ b/backend/internal/service/user_group_rate.go @@ -2,6 +2,22 @@ package service import "context" +// UserGroupRateEntry 分组下用户专属倍率条目 +type UserGroupRateEntry struct { + UserID int64 `json:"user_id"` + UserName string `json:"user_name"` + UserEmail string `json:"user_email"` + UserNotes string `json:"user_notes"` + UserStatus string `json:"user_status"` + RateMultiplier float64 `json:"rate_multiplier"` +} + +// GroupRateMultiplierInput 批量设置分组倍率的输入条目 +type GroupRateMultiplierInput struct { + UserID int64 `json:"user_id"` + RateMultiplier float64 `json:"rate_multiplier"` +} + // UserGroupRateRepository 用户专属分组倍率仓储接口 // 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率 type UserGroupRateRepository interface { @@ -13,10 +29,16 @@ type UserGroupRateRepository interface { // 如果未设置专属倍率,返回 nil GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) + // GetByGroupID 获取指定分组下所有用户的专属倍率 + GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) + // SyncUserGroupRates 同步用户的分组专属倍率 // rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率 SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error + // SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组数据) + SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error + // DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用) DeleteByGroupID(ctx context.Context, groupID int64) error diff --git a/backend/internal/service/user_group_rate_resolver.go b/backend/internal/service/user_group_rate_resolver.go new file mode 100644 
index 00000000..7f0ffb0f --- /dev/null +++ b/backend/internal/service/user_group_rate_resolver.go @@ -0,0 +1,103 @@ +package service + +import ( + "context" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + gocache "github.com/patrickmn/go-cache" + "golang.org/x/sync/singleflight" +) + +type userGroupRateResolver struct { + repo UserGroupRateRepository + cache *gocache.Cache + cacheTTL time.Duration + sf *singleflight.Group + logComponent string +} + +func newUserGroupRateResolver(repo UserGroupRateRepository, cache *gocache.Cache, cacheTTL time.Duration, sf *singleflight.Group, logComponent string) *userGroupRateResolver { + if cacheTTL <= 0 { + cacheTTL = defaultUserGroupRateCacheTTL + } + if cache == nil { + cache = gocache.New(cacheTTL, time.Minute) + } + if logComponent == "" { + logComponent = "service.gateway" + } + if sf == nil { + sf = &singleflight.Group{} + } + + return &userGroupRateResolver{ + repo: repo, + cache: cache, + cacheTTL: cacheTTL, + sf: sf, + logComponent: logComponent, + } +} + +func (r *userGroupRateResolver) Resolve(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { + if r == nil || userID <= 0 || groupID <= 0 { + return groupDefaultMultiplier + } + + key := fmt.Sprintf("%d:%d", userID, groupID) + if r.cache != nil { + if cached, ok := r.cache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier + } + } + } + if r.repo == nil { + return groupDefaultMultiplier + } + userGroupRateCacheMissTotal.Add(1) + + value, err, shared := r.sf.Do(key, func() (any, error) { + if r.cache != nil { + if cached, ok := r.cache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier, nil + } + } + } + + userGroupRateCacheLoadTotal.Add(1) + userRate, repoErr := r.repo.GetByUserAndGroup(ctx, userID, groupID) + if repoErr != nil { + return nil, repoErr + } + + 
multiplier := groupDefaultMultiplier + if userRate != nil { + multiplier = *userRate + } + if r.cache != nil { + r.cache.Set(key, multiplier, r.cacheTTL) + } + return multiplier, nil + }) + if shared { + userGroupRateCacheSFSharedTotal.Add(1) + } + if err != nil { + userGroupRateCacheFallbackTotal.Add(1) + logger.LegacyPrintf(r.logComponent, "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err) + return groupDefaultMultiplier + } + + multiplier, ok := value.(float64) + if !ok { + userGroupRateCacheFallbackTotal.Add(1) + return groupDefaultMultiplier + } + return multiplier +} diff --git a/backend/internal/service/user_group_rate_resolver_test.go b/backend/internal/service/user_group_rate_resolver_test.go new file mode 100644 index 00000000..064ef7ba --- /dev/null +++ b/backend/internal/service/user_group_rate_resolver_test.go @@ -0,0 +1,83 @@ +package service + +import ( + "context" + "testing" + "time" + + gocache "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/require" +) + +type userGroupRateResolverRepoStub struct { + UserGroupRateRepository + + rate *float64 + err error + calls int +} + +func (s *userGroupRateResolverRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + s.calls++ + if s.err != nil { + return nil, s.err + } + return s.rate, nil +} + +func TestNewUserGroupRateResolver_Defaults(t *testing.T) { + resolver := newUserGroupRateResolver(nil, nil, 0, nil, "") + + require.NotNil(t, resolver) + require.NotNil(t, resolver.cache) + require.Equal(t, defaultUserGroupRateCacheTTL, resolver.cacheTTL) + require.NotNil(t, resolver.sf) + require.Equal(t, "service.gateway", resolver.logComponent) +} + +func TestUserGroupRateResolverResolve_FallbackForNilResolverAndInvalidIDs(t *testing.T) { + var nilResolver *userGroupRateResolver + require.Equal(t, 1.4, nilResolver.Resolve(context.Background(), 101, 202, 1.4)) + + resolver := newUserGroupRateResolver(nil, nil, 
time.Second, nil, "service.test") + require.Equal(t, 1.4, resolver.Resolve(context.Background(), 0, 202, 1.4)) + require.Equal(t, 1.4, resolver.Resolve(context.Background(), 101, 0, 1.4)) +} + +func TestUserGroupRateResolverResolve_InvalidCacheEntryLoadsRepoAndCaches(t *testing.T) { + resetGatewayHotpathStatsForTest() + + rate := 1.7 + repo := &userGroupRateResolverRepoStub{rate: &rate} + cache := gocache.New(time.Minute, time.Minute) + cache.Set("101:202", "bad-cache", time.Minute) + resolver := newUserGroupRateResolver(repo, cache, time.Minute, nil, "service.test") + + got := resolver.Resolve(context.Background(), 101, 202, 1.2) + require.Equal(t, rate, got) + require.Equal(t, 1, repo.calls) + + cached, ok := cache.Get("101:202") + require.True(t, ok) + require.Equal(t, rate, cached) + + hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats() + require.Equal(t, int64(0), hit) + require.Equal(t, int64(1), miss) + require.Equal(t, int64(1), load) + require.Equal(t, int64(0), fallback) +} + +func TestGatewayServiceGetUserGroupRateMultiplier_FallbacksAndUsesExistingResolver(t *testing.T) { + var nilSvc *GatewayService + require.Equal(t, 1.3, nilSvc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.3)) + + rate := 1.9 + repo := &userGroupRateResolverRepoStub{rate: &rate} + resolver := newUserGroupRateResolver(repo, nil, time.Minute, nil, "service.gateway") + svc := &GatewayService{userGroupRateResolver: resolver} + + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2) + require.Equal(t, rate, got) + require.Equal(t, 1, repo.calls) +} diff --git a/backend/internal/service/user_msg_queue_service.go b/backend/internal/service/user_msg_queue_service.go new file mode 100644 index 00000000..a0ce95a8 --- /dev/null +++ b/backend/internal/service/user_msg_queue_service.go @@ -0,0 +1,318 @@ +package service + +import ( + "context" + cryptorand "crypto/rand" + "encoding/hex" + "fmt" + "math" + "math/rand/v2" + "sync" + "time" + + 
"github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// UserMsgQueueCache 用户消息串行队列 Redis 缓存接口 +type UserMsgQueueCache interface { + // AcquireLock 尝试获取账号级串行锁 + AcquireLock(ctx context.Context, accountID int64, requestID string, lockTtlMs int) (acquired bool, err error) + // ReleaseLock 释放锁并记录完成时间 + ReleaseLock(ctx context.Context, accountID int64, requestID string) (released bool, err error) + // GetLastCompletedMs 获取上次完成时间(毫秒时间戳,Redis TIME 源) + GetLastCompletedMs(ctx context.Context, accountID int64) (int64, error) + // GetCurrentTimeMs 获取 Redis 服务器当前时间(毫秒),与 ReleaseLock 记录的时间源一致 + GetCurrentTimeMs(ctx context.Context) (int64, error) + // ForceReleaseLock 强制释放锁(孤儿锁清理) + ForceReleaseLock(ctx context.Context, accountID int64) error + // ScanLockKeys 扫描 PTTL == -1 的孤儿锁 key,返回 accountID 列表 + ScanLockKeys(ctx context.Context, maxCount int) ([]int64, error) +} + +// QueueLockResult 锁获取结果 +type QueueLockResult struct { + Acquired bool + RequestID string +} + +// UserMessageQueueService 用户消息串行队列服务 +// 对真实用户消息实施账号级串行化 + RPM 自适应延迟 +type UserMessageQueueService struct { + cache UserMsgQueueCache + rpmCache RPMCache + cfg *config.UserMessageQueueConfig + stopCh chan struct{} // graceful shutdown + stopOnce sync.Once // 确保 Stop() 并发安全 +} + +// NewUserMessageQueueService 创建用户消息串行队列服务 +func NewUserMessageQueueService(cache UserMsgQueueCache, rpmCache RPMCache, cfg *config.UserMessageQueueConfig) *UserMessageQueueService { + return &UserMessageQueueService{ + cache: cache, + rpmCache: rpmCache, + cfg: cfg, + stopCh: make(chan struct{}), + } +} + +// IsRealUserMessage 检测是否为真实用户消息(非 tool_result) +// 与 claude-relay-service 的检测逻辑一致: +// 1. messages 非空 +// 2. 最后一条消息 role == "user" +// 3. 
最后一条消息 content(如果是数组)中不含 type:"tool_result" / "tool_use_result" +func IsRealUserMessage(parsed *ParsedRequest) bool { + if parsed == nil || len(parsed.Messages) == 0 { + return false + } + + lastMsg := parsed.Messages[len(parsed.Messages)-1] + msgMap, ok := lastMsg.(map[string]any) + if !ok { + return false + } + + role, _ := msgMap["role"].(string) + if role != "user" { + return false + } + + // 检查 content 是否包含 tool_result 类型 + content, ok := msgMap["content"] + if !ok { + return true // 没有 content 字段,视为普通用户消息 + } + + contentArr, ok := content.([]any) + if !ok { + return true // content 不是数组(可能是 string),视为普通用户消息 + } + + for _, item := range contentArr { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType == "tool_result" || itemType == "tool_use_result" { + return false + } + } + return true +} + +// TryAcquire 尝试立即获取串行锁 +func (s *UserMessageQueueService) TryAcquire(ctx context.Context, accountID int64) (*QueueLockResult, error) { + if s.cache == nil { + return &QueueLockResult{Acquired: true}, nil // fail-open + } + + requestID := generateUMQRequestID() + lockTTL := s.cfg.LockTTLMs + if lockTTL <= 0 { + lockTTL = 120000 + } + + acquired, err := s.cache.AcquireLock(ctx, accountID, requestID, lockTTL) + if err != nil { + logger.LegacyPrintf("service.umq", "AcquireLock failed for account %d: %v", accountID, err) + return &QueueLockResult{Acquired: true}, nil // fail-open + } + + return &QueueLockResult{ + Acquired: acquired, + RequestID: requestID, + }, nil +} + +// Release 释放串行锁 +func (s *UserMessageQueueService) Release(ctx context.Context, accountID int64, requestID string) error { + if s.cache == nil || requestID == "" { + return nil + } + released, err := s.cache.ReleaseLock(ctx, accountID, requestID) + if err != nil { + logger.LegacyPrintf("service.umq", "ReleaseLock failed for account %d: %v", accountID, err) + return err + } + if !released { + logger.LegacyPrintf("service.umq", "ReleaseLock 
no-op for account %d (requestID mismatch or expired)", accountID) + } + return nil +} + +// EnforceDelay 根据 RPM 负载执行自适应延迟 +// 使用 Redis TIME 确保与 releaseLockScript 记录的时间源一致 +func (s *UserMessageQueueService) EnforceDelay(ctx context.Context, accountID int64, baseRPM int) error { + if s.cache == nil { + return nil + } + + // 先检查历史记录:没有历史则无需延迟,避免不必要的 RPM 查询 + lastMs, err := s.cache.GetLastCompletedMs(ctx, accountID) + if err != nil { + logger.LegacyPrintf("service.umq", "GetLastCompletedMs failed for account %d: %v", accountID, err) + return nil // fail-open + } + if lastMs == 0 { + return nil // 没有历史记录,无需延迟 + } + + delay := s.CalculateRPMAwareDelay(ctx, accountID, baseRPM) + if delay <= 0 { + return nil + } + + // 获取 Redis 当前时间(与 lastMs 同源,避免时钟偏差) + nowMs, err := s.cache.GetCurrentTimeMs(ctx) + if err != nil { + logger.LegacyPrintf("service.umq", "GetCurrentTimeMs failed: %v", err) + return nil // fail-open + } + + elapsed := time.Duration(nowMs-lastMs) * time.Millisecond + if elapsed < 0 { + // 时钟异常(Redis 故障转移等),fail-open + return nil + } + remaining := delay - elapsed + if remaining <= 0 { + return nil + } + + // 执行延迟 + timer := time.NewTimer(remaining) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +// CalculateRPMAwareDelay 根据当前 RPM 负载计算自适应延迟 +// ratio = currentRPM / baseRPM +// ratio < 0.5 → MinDelay +// 0.5 ≤ ratio < 0.8 → 线性插值 MinDelay..MaxDelay +// ratio ≥ 0.8 → MaxDelay +// 返回值包含 ±15% 随机抖动(anti-detection + 避免惊群效应) +func (s *UserMessageQueueService) CalculateRPMAwareDelay(ctx context.Context, accountID int64, baseRPM int) time.Duration { + minDelay := time.Duration(s.cfg.MinDelayMs) * time.Millisecond + maxDelay := time.Duration(s.cfg.MaxDelayMs) * time.Millisecond + + if minDelay <= 0 { + minDelay = 200 * time.Millisecond + } + if maxDelay <= 0 { + maxDelay = 2000 * time.Millisecond + } + // 防止配置错误:minDelay > maxDelay 时交换 + if minDelay > maxDelay { + minDelay, maxDelay = maxDelay, minDelay + } + 
+ var baseDelay time.Duration + + if baseRPM <= 0 || s.rpmCache == nil { + baseDelay = minDelay + } else { + currentRPM, err := s.rpmCache.GetRPM(ctx, accountID) + if err != nil { + logger.LegacyPrintf("service.umq", "GetRPM failed for account %d: %v", accountID, err) + baseDelay = minDelay // fail-open + } else { + ratio := float64(currentRPM) / float64(baseRPM) + if ratio < 0.5 { + baseDelay = minDelay + } else if ratio >= 0.8 { + baseDelay = maxDelay + } else { + // 线性插值: 0.5 → minDelay, 0.8 → maxDelay + t := (ratio - 0.5) / 0.3 + interpolated := float64(minDelay) + t*(float64(maxDelay)-float64(minDelay)) + baseDelay = time.Duration(math.Round(interpolated)) + } + } + } + + // ±15% 随机抖动 + return applyJitter(baseDelay, 0.15) +} + +// StartCleanupWorker 启动孤儿锁清理 worker +// 定期 SCAN umq:*:lock 并清理 PTTL == -1 的异常锁(PTTL 检查在 cache.ScanLockKeys 内完成) +func (s *UserMessageQueueService) StartCleanupWorker(interval time.Duration) { + if s == nil || s.cache == nil || interval <= 0 { + return + } + + runCleanup := func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + accountIDs, err := s.cache.ScanLockKeys(ctx, 1000) + if err != nil { + logger.LegacyPrintf("service.umq", "Cleanup scan failed: %v", err) + return + } + + cleaned := 0 + for _, accountID := range accountIDs { + cleanCtx, cleanCancel := context.WithTimeout(context.Background(), 2*time.Second) + if err := s.cache.ForceReleaseLock(cleanCtx, accountID); err != nil { + logger.LegacyPrintf("service.umq", "Cleanup force release failed for account %d: %v", accountID, err) + } else { + cleaned++ + } + cleanCancel() + } + + if cleaned > 0 { + logger.LegacyPrintf("service.umq", "Cleanup completed: released %d orphaned locks", cleaned) + } + } + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-s.stopCh: + return + case <-ticker.C: + runCleanup() + } + } + }() +} + +// Stop 停止后台 cleanup worker +func (s 
*UserMessageQueueService) Stop() { + if s != nil && s.stopCh != nil { + s.stopOnce.Do(func() { + close(s.stopCh) + }) + } +} + +// applyJitter 对延迟值施加 ±jitterPct 的随机抖动 +// 使用 math/rand/v2(Go 1.22+ 自动使用 crypto/rand 种子),与 nextBackoff 一致 +// 例如 applyJitter(200ms, 0.15) 返回 170ms ~ 230ms +func applyJitter(d time.Duration, jitterPct float64) time.Duration { + if d <= 0 || jitterPct <= 0 { + return d + } + // [-jitterPct, +jitterPct] + jitter := (rand.Float64()*2 - 1) * jitterPct + return time.Duration(float64(d) * (1 + jitter)) +} + +// generateUMQRequestID 生成唯一请求 ID(与 generateRequestID 一致的 fallback 模式) +func generateUMQRequestID() string { + b := make([]byte, 16) + if _, err := cryptorand.Read(b); err != nil { + return fmt.Sprintf("%x", time.Now().UnixNano()) + } + return hex.EncodeToString(b) +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index b5553935..49ba3645 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -22,6 +22,10 @@ type UserListFilters struct { Role string // User role filter Search string // Search in email, username Attributes map[int64]string // Custom attribute filters: attributeID -> value + // IncludeSubscriptions controls whether ListWithFilters should load active subscriptions. + // For large datasets this can be expensive; admin list pages should enable it on demand. + // nil means not specified (default: load subscriptions for backward compatibility). 
+ IncludeSubscriptions *bool } type UserRepository interface { diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 5ba2b99e..05fe5056 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -96,6 +96,18 @@ func (m *mockBillingCache) UpdateSubscriptionUsage(context.Context, int64, int64 func (m *mockBillingCache) InvalidateSubscriptionCache(context.Context, int64, int64) error { return nil } +func (m *mockBillingCache) GetAPIKeyRateLimit(context.Context, int64) (*APIKeyRateLimitCacheData, error) { + return nil, nil +} +func (m *mockBillingCache) SetAPIKeyRateLimit(context.Context, int64, *APIKeyRateLimitCacheData) error { + return nil +} +func (m *mockBillingCache) UpdateAPIKeyRateLimitUsage(context.Context, int64, float64) error { + return nil +} +func (m *mockBillingCache) InvalidateAPIKeyRateLimit(context.Context, int64) error { + return nil +} // --- 测试 --- diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index b0eccb71..7da72630 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -48,14 +48,80 @@ func ProvideTokenRefreshService( cacheInvalidator TokenCacheInvalidator, schedulerCache SchedulerCache, cfg *config.Config, + tempUnschedCache TempUnschedCache, + privacyClientFactory PrivacyClientFactory, + proxyRepo ProxyRepository, + refreshAPI *OAuthRefreshAPI, ) *TokenRefreshService { - svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg) + svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache) // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 svc.SetSoraAccountRepo(soraAccountRepo) + // 注入 OpenAI privacy opt-out 依赖 + 
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo) + // 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件) + svc.SetRefreshAPI(refreshAPI) + // 调用侧显式注入后台刷新策略,避免策略漂移 + svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy()) svc.Start() return svc } +// ProvideClaudeTokenProvider creates ClaudeTokenProvider with OAuthRefreshAPI injection +func ProvideClaudeTokenProvider( + accountRepo AccountRepository, + tokenCache GeminiTokenCache, + oauthService *OAuthService, + refreshAPI *OAuthRefreshAPI, +) *ClaudeTokenProvider { + p := NewClaudeTokenProvider(accountRepo, tokenCache, oauthService) + executor := NewClaudeTokenRefresher(oauthService) + p.SetRefreshAPI(refreshAPI, executor) + p.SetRefreshPolicy(ClaudeProviderRefreshPolicy()) + return p +} + +// ProvideOpenAITokenProvider creates OpenAITokenProvider with OAuthRefreshAPI injection +func ProvideOpenAITokenProvider( + accountRepo AccountRepository, + tokenCache GeminiTokenCache, + openaiOAuthService *OpenAIOAuthService, + refreshAPI *OAuthRefreshAPI, +) *OpenAITokenProvider { + p := NewOpenAITokenProvider(accountRepo, tokenCache, openaiOAuthService) + executor := NewOpenAITokenRefresher(openaiOAuthService, accountRepo) + p.SetRefreshAPI(refreshAPI, executor) + p.SetRefreshPolicy(OpenAIProviderRefreshPolicy()) + return p +} + +// ProvideGeminiTokenProvider creates GeminiTokenProvider with OAuthRefreshAPI injection +func ProvideGeminiTokenProvider( + accountRepo AccountRepository, + tokenCache GeminiTokenCache, + geminiOAuthService *GeminiOAuthService, + refreshAPI *OAuthRefreshAPI, +) *GeminiTokenProvider { + p := NewGeminiTokenProvider(accountRepo, tokenCache, geminiOAuthService) + executor := NewGeminiTokenRefresher(geminiOAuthService) + p.SetRefreshAPI(refreshAPI, executor) + p.SetRefreshPolicy(GeminiProviderRefreshPolicy()) + return p +} + +// ProvideAntigravityTokenProvider creates AntigravityTokenProvider with OAuthRefreshAPI injection +func ProvideAntigravityTokenProvider( + accountRepo 
AccountRepository, + tokenCache GeminiTokenCache, + antigravityOAuthService *AntigravityOAuthService, + refreshAPI *OAuthRefreshAPI, +) *AntigravityTokenProvider { + p := NewAntigravityTokenProvider(accountRepo, tokenCache, antigravityOAuthService) + executor := NewAntigravityTokenRefresher(antigravityOAuthService) + p.SetRefreshAPI(refreshAPI, executor) + p.SetRefreshPolicy(AntigravityProviderRefreshPolicy()) + return p +} + // ProvideDashboardAggregationService 创建并启动仪表盘聚合服务 func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService { svc := NewDashboardAggregationService(repo, timingWheel, cfg) @@ -104,12 +170,24 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh // ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker. func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService { svc := NewConcurrencyService(cache) + if err := svc.CleanupStaleProcessSlots(context.Background()); err != nil { + logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err) + } if cfg != nil { svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval) } return svc } +// ProvideUserMessageQueueService 创建用户消息串行队列服务并启动清理 worker +func ProvideUserMessageQueueService(cache UserMsgQueueCache, rpmCache RPMCache, cfg *config.Config) *UserMessageQueueService { + svc := NewUserMessageQueueService(cache, rpmCache, &cfg.Gateway.UserMessageQueue) + if cfg.Gateway.UserMessageQueue.CleanupIntervalSeconds > 0 { + svc.StartCleanupWorker(time.Duration(cfg.Gateway.UserMessageQueue.CleanupIntervalSeconds) * time.Second) + } + return svc +} + // ProvideSchedulerSnapshotService creates and starts SchedulerSnapshotService. 
func ProvideSchedulerSnapshotService( cache SchedulerCache, @@ -264,6 +342,27 @@ func ProvideIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Co return svc } +// ProvideScheduledTestService creates ScheduledTestService. +func ProvideScheduledTestService( + planRepo ScheduledTestPlanRepository, + resultRepo ScheduledTestResultRepository, +) *ScheduledTestService { + return NewScheduledTestService(planRepo, resultRepo) +} + +// ProvideScheduledTestRunnerService creates and starts ScheduledTestRunnerService. +func ProvideScheduledTestRunnerService( + planRepo ScheduledTestPlanRepository, + scheduledSvc *ScheduledTestService, + accountTestSvc *AccountTestService, + rateLimitSvc *RateLimitService, + cfg *config.Config, +) *ScheduledTestRunnerService { + svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, rateLimitSvc, cfg) + svc.Start() + return svc +} + // ProvideOpsScheduledReportService creates and starts OpsScheduledReportService. func ProvideOpsScheduledReportService( opsService *OpsService, @@ -284,6 +383,19 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC return apiKeyService } +// ProvideBackupService creates and starts BackupService +func ProvideBackupService( + settingRepo SettingRepository, + cfg *config.Config, + encryptor SecretEncryptor, + storeFactory BackupObjectStoreFactory, + dumper DBDumper, +) *BackupService { + svc := NewBackupService(settingRepo, cfg, encryptor, storeFactory, dumper) + svc.Start() + return svc +} + // ProvideSettingService wires SettingService with group reader for default subscription validation. 
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService { svc := NewSettingService(settingRepo, cfg) @@ -324,17 +436,19 @@ var ProviderSet = wire.NewSet( NewCompositeTokenCacheInvalidator, wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)), NewAntigravityOAuthService, - NewGeminiTokenProvider, + NewOAuthRefreshAPI, + ProvideGeminiTokenProvider, NewGeminiMessagesCompatService, - NewAntigravityTokenProvider, - NewOpenAITokenProvider, - NewClaudeTokenProvider, + ProvideAntigravityTokenProvider, + ProvideOpenAITokenProvider, + ProvideClaudeTokenProvider, NewAntigravityGatewayService, ProvideRateLimitService, NewAccountUsageService, NewAccountTestService, ProvideSettingService, NewDataManagementService, + ProvideBackupService, ProvideOpsSystemLogSink, NewOpsService, ProvideOpsMetricsCollector, @@ -348,6 +462,7 @@ var ProviderSet = wire.NewSet( NewSubscriptionService, wire.Bind(new(DefaultSubscriptionAssigner), new(*SubscriptionService)), ProvideConcurrencyService, + ProvideUserMessageQueueService, NewUsageRecordWorkerPool, ProvideSchedulerSnapshotService, NewIdentityService, @@ -369,4 +484,6 @@ var ProviderSet = wire.NewSet( ProvideIdempotencyCoordinator, ProvideSystemOperationLockService, ProvideIdempotencyCleanupService, + ProvideScheduledTestService, + ProvideScheduledTestRunnerService, ) diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index 83c32db3..de3b765a 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/service" @@ -23,10 +24,19 @@ import ( // Config paths const ( - ConfigFileName = "config.yaml" - InstallLockFile = ".installed" + ConfigFileName = "config.yaml" + InstallLockFile = 
".installed" + defaultUserConcurrency = 5 + simpleModeAdminConcurrency = 30 ) +func setupDefaultAdminConcurrency() int { + if strings.EqualFold(strings.TrimSpace(os.Getenv("RUN_MODE")), config.RunModeSimple) { + return simpleModeAdminConcurrency + } + return defaultUserConcurrency +} + // GetDataDir returns the data directory for storing config and lock files. // Priority: DATA_DIR env > /app/data (if exists and writable) > current directory func GetDataDir() string { @@ -390,7 +400,7 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) { Role: service.RoleAdmin, Status: service.StatusActive, Balance: 0, - Concurrency: 5, + Concurrency: setupDefaultAdminConcurrency(), CreatedAt: time.Now(), UpdatedAt: time.Now(), } @@ -462,7 +472,7 @@ func writeConfigFile(cfg *SetupConfig) error { APIKeyPrefix string `yaml:"api_key_prefix"` RateMultiplier float64 `yaml:"rate_multiplier"` }{ - UserConcurrency: 5, + UserConcurrency: defaultUserConcurrency, UserBalance: 0, APIKeyPrefix: "sk-", RateMultiplier: 1.0, diff --git a/backend/internal/setup/setup_test.go b/backend/internal/setup/setup_test.go index 69655e92..a01dd00c 100644 --- a/backend/internal/setup/setup_test.go +++ b/backend/internal/setup/setup_test.go @@ -1,6 +1,10 @@ package setup -import "testing" +import ( + "os" + "strings" + "testing" +) func TestDecideAdminBootstrap(t *testing.T) { t.Parallel() @@ -49,3 +53,37 @@ func TestDecideAdminBootstrap(t *testing.T) { }) } } + +func TestSetupDefaultAdminConcurrency(t *testing.T) { + t.Run("simple mode admin uses higher concurrency", func(t *testing.T) { + t.Setenv("RUN_MODE", "simple") + if got := setupDefaultAdminConcurrency(); got != simpleModeAdminConcurrency { + t.Fatalf("setupDefaultAdminConcurrency()=%d, want %d", got, simpleModeAdminConcurrency) + } + }) + + t.Run("standard mode keeps existing default", func(t *testing.T) { + t.Setenv("RUN_MODE", "standard") + if got := setupDefaultAdminConcurrency(); got != defaultUserConcurrency { + 
t.Fatalf("setupDefaultAdminConcurrency()=%d, want %d", got, defaultUserConcurrency) + } + }) +} + +func TestWriteConfigFileKeepsDefaultUserConcurrency(t *testing.T) { + t.Setenv("RUN_MODE", "simple") + t.Setenv("DATA_DIR", t.TempDir()) + + if err := writeConfigFile(&SetupConfig{}); err != nil { + t.Fatalf("writeConfigFile() error = %v", err) + } + + data, err := os.ReadFile(GetConfigFilePath()) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + + if !strings.Contains(string(data), "user_concurrency: 5") { + t.Fatalf("config missing default user concurrency, got:\n%s", string(data)) + } +} diff --git a/backend/internal/testutil/stubs.go b/backend/internal/testutil/stubs.go index 217a5f56..bc572e11 100644 --- a/backend/internal/testutil/stubs.go +++ b/backend/internal/testutil/stubs.go @@ -76,6 +76,9 @@ func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acco func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { return nil } +func (c StubConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, _ string) error { + return nil +} // ============================================================ // StubGatewayCache — service.GatewayCache 的空实现 diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index f7ba5c9e..41ce4d48 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -83,14 +83,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc { path := c.Request.URL.Path // Skip API routes - if strings.HasPrefix(path, "/api/") || - strings.HasPrefix(path, "/v1/") || - strings.HasPrefix(path, "/v1beta/") || - strings.HasPrefix(path, "/sora/") || - strings.HasPrefix(path, "/antigravity/") || - strings.HasPrefix(path, "/setup/") || - path == "/health" || - path == "/responses" { + if shouldBypassEmbeddedFrontend(path) { c.Next() return } @@ -207,14 +200,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { return func(c *gin.Context) { 
path := c.Request.URL.Path - if strings.HasPrefix(path, "/api/") || - strings.HasPrefix(path, "/v1/") || - strings.HasPrefix(path, "/v1beta/") || - strings.HasPrefix(path, "/sora/") || - strings.HasPrefix(path, "/antigravity/") || - strings.HasPrefix(path, "/setup/") || - path == "/health" || - path == "/responses" { + if shouldBypassEmbeddedFrontend(path) { c.Next() return } @@ -235,6 +221,19 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { } } +func shouldBypassEmbeddedFrontend(path string) bool { + trimmed := strings.TrimSpace(path) + return strings.HasPrefix(trimmed, "/api/") || + strings.HasPrefix(trimmed, "/v1/") || + strings.HasPrefix(trimmed, "/v1beta/") || + strings.HasPrefix(trimmed, "/sora/") || + strings.HasPrefix(trimmed, "/antigravity/") || + strings.HasPrefix(trimmed, "/setup/") || + trimmed == "/health" || + trimmed == "/responses" || + strings.HasPrefix(trimmed, "/responses/") +} + func serveIndexHTML(c *gin.Context, fsys fs.FS) { file, err := fsys.Open("index.html") if err != nil { diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go index e2cbcf15..f270b624 100644 --- a/backend/internal/web/embed_test.go +++ b/backend/internal/web/embed_test.go @@ -367,6 +367,7 @@ func TestFrontendServer_Middleware(t *testing.T) { "/setup/init", "/health", "/responses", + "/responses/compact", } for _, path := range apiPaths { @@ -388,6 +389,32 @@ func TestFrontendServer_Middleware(t *testing.T) { } }) + t.Run("skips_responses_compact_post_routes", func(t *testing.T) { + provider := &mockSettingsProvider{ + settings: map[string]string{"test": "value"}, + } + + server, err := NewFrontendServer(provider) + require.NoError(t, err) + + router := gin.New() + router.Use(server.Middleware()) + nextCalled := false + router.POST("/responses/compact", func(c *gin.Context) { + nextCalled = true + c.String(http.StatusOK, `{"ok":true}`) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/responses/compact", 
strings.NewReader(`{"model":"gpt-5"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + assert.True(t, nextCalled, "next handler should be called for compact API route") + assert.Equal(t, http.StatusOK, w.Code) + assert.JSONEq(t, `{"ok":true}`, w.Body.String()) + }) + t.Run("serves_index_for_spa_routes", func(t *testing.T) { provider := &mockSettingsProvider{ settings: map[string]string{"test": "value"}, @@ -543,6 +570,7 @@ func TestServeEmbeddedFrontend(t *testing.T) { "/setup/init", "/health", "/responses", + "/responses/compact", } for _, path := range apiPaths { diff --git a/backend/migrations/061_add_usage_log_request_type.sql b/backend/migrations/061_add_usage_log_request_type.sql index 68a33d51..d2a9f446 100644 --- a/backend/migrations/061_add_usage_log_request_type.sql +++ b/backend/migrations/061_add_usage_log_request_type.sql @@ -19,11 +19,47 @@ $$; CREATE INDEX IF NOT EXISTS idx_usage_logs_request_type_created_at ON usage_logs (request_type, created_at); --- Backfill from legacy fields. openai_ws_mode has higher priority than stream. -UPDATE usage_logs -SET request_type = CASE - WHEN openai_ws_mode = TRUE THEN 3 - WHEN stream = TRUE THEN 2 - ELSE 1 +-- Backfill from legacy fields in bounded batches. +-- Why bounded: +-- 1) Full-table UPDATE on large usage_logs can block startup for a long time. +-- 2) request_type=0 rows remain query-compatible via legacy fallback logic +-- (stream/openai_ws_mode) in repository filters. +-- 3) Subsequent writes will use explicit request_type and gradually dilute +-- historical unknown rows. +-- +-- openai_ws_mode has higher priority than stream. 
+DO $$ +DECLARE + v_rows INTEGER := 0; + v_total_rows INTEGER := 0; + v_batch_size INTEGER := 5000; + v_started_at TIMESTAMPTZ := clock_timestamp(); + v_max_duration INTERVAL := INTERVAL '8 seconds'; +BEGIN + LOOP + WITH batch AS ( + SELECT id + FROM usage_logs + WHERE request_type = 0 + ORDER BY id + LIMIT v_batch_size + ) + UPDATE usage_logs ul + SET request_type = CASE + WHEN ul.openai_ws_mode = TRUE THEN 3 + WHEN ul.stream = TRUE THEN 2 + ELSE 1 + END + FROM batch + WHERE ul.id = batch.id; + + GET DIAGNOSTICS v_rows = ROW_COUNT; + EXIT WHEN v_rows = 0; + + v_total_rows := v_total_rows + v_rows; + EXIT WHEN clock_timestamp() - v_started_at >= v_max_duration; + END LOOP; + + RAISE NOTICE 'usage_logs.request_type startup backfill rows=%', v_total_rows; END -WHERE request_type = 0; +$$; diff --git a/backend/migrations/064_add_api_key_rate_limits.sql b/backend/migrations/064_add_api_key_rate_limits.sql new file mode 100644 index 00000000..9e310f1d --- /dev/null +++ b/backend/migrations/064_add_api_key_rate_limits.sql @@ -0,0 +1,15 @@ +-- Add rate limit fields to api_keys table +-- Rate limit configuration (0 = unlimited) +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS rate_limit_5h decimal(20,8) NOT NULL DEFAULT 0; +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS rate_limit_1d decimal(20,8) NOT NULL DEFAULT 0; +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS rate_limit_7d decimal(20,8) NOT NULL DEFAULT 0; + +-- Rate limit usage tracking +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS usage_5h decimal(20,8) NOT NULL DEFAULT 0; +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS usage_1d decimal(20,8) NOT NULL DEFAULT 0; +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS usage_7d decimal(20,8) NOT NULL DEFAULT 0; + +-- Window start times (nullable) +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS window_5h_start timestamptz; +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS window_1d_start timestamptz; +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS window_7d_start timestamptz; diff --git 
a/backend/migrations/065_add_search_trgm_indexes.sql b/backend/migrations/065_add_search_trgm_indexes.sql new file mode 100644 index 00000000..f5efb5da --- /dev/null +++ b/backend/migrations/065_add_search_trgm_indexes.sql @@ -0,0 +1,33 @@ +-- Improve admin fuzzy-search performance on large datasets. +-- Best effort: +-- 1) try enabling pg_trgm +-- 2) only create trigram indexes when extension is available +DO $$ +BEGIN + BEGIN + CREATE EXTENSION IF NOT EXISTS pg_trgm; + EXCEPTION + WHEN OTHERS THEN + RAISE NOTICE 'pg_trgm extension not created: %', SQLERRM; + END; + + IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_trgm') THEN + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_email_trgm + ON users USING gin (email gin_trgm_ops)'; + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_username_trgm + ON users USING gin (username gin_trgm_ops)'; + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_notes_trgm + ON users USING gin (notes gin_trgm_ops)'; + + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_accounts_name_trgm + ON accounts USING gin (name gin_trgm_ops)'; + + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_api_keys_key_trgm + ON api_keys USING gin ("key" gin_trgm_ops)'; + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_api_keys_name_trgm + ON api_keys USING gin (name gin_trgm_ops)'; + ELSE + RAISE NOTICE 'skip trigram indexes because pg_trgm is unavailable'; + END IF; +END +$$; diff --git a/backend/migrations/066_add_scheduled_test_tables.sql b/backend/migrations/066_add_scheduled_test_tables.sql new file mode 100644 index 00000000..a9f839c0 --- /dev/null +++ b/backend/migrations/066_add_scheduled_test_tables.sql @@ -0,0 +1,30 @@ +-- 066_add_scheduled_test_tables.sql +-- Scheduled account test plans and results + +CREATE TABLE IF NOT EXISTS scheduled_test_plans ( + id BIGSERIAL PRIMARY KEY, + account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE, + model_id VARCHAR(100) NOT NULL DEFAULT '', + cron_expression VARCHAR(100) NOT NULL DEFAULT '*/30 * * * *', + enabled 
BOOLEAN NOT NULL DEFAULT true, + max_results INT NOT NULL DEFAULT 50, + last_run_at TIMESTAMPTZ, + next_run_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_stp_account_id ON scheduled_test_plans(account_id); +CREATE INDEX IF NOT EXISTS idx_stp_enabled_next_run ON scheduled_test_plans(enabled, next_run_at) WHERE enabled = true; + +CREATE TABLE IF NOT EXISTS scheduled_test_results ( + id BIGSERIAL PRIMARY KEY, + plan_id BIGINT NOT NULL REFERENCES scheduled_test_plans(id) ON DELETE CASCADE, + status VARCHAR(20) NOT NULL DEFAULT 'success', + response_text TEXT NOT NULL DEFAULT '', + error_message TEXT NOT NULL DEFAULT '', + latency_ms BIGINT NOT NULL DEFAULT 0, + started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + finished_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_str_plan_created ON scheduled_test_results(plan_id, created_at DESC); diff --git a/backend/migrations/067_add_account_load_factor.sql b/backend/migrations/067_add_account_load_factor.sql new file mode 100644 index 00000000..6805e8c2 --- /dev/null +++ b/backend/migrations/067_add_account_load_factor.sql @@ -0,0 +1 @@ +ALTER TABLE accounts ADD COLUMN IF NOT EXISTS load_factor INTEGER; diff --git a/backend/migrations/068_add_announcement_notify_mode.sql b/backend/migrations/068_add_announcement_notify_mode.sql new file mode 100644 index 00000000..28deb983 --- /dev/null +++ b/backend/migrations/068_add_announcement_notify_mode.sql @@ -0,0 +1 @@ +ALTER TABLE announcements ADD COLUMN IF NOT EXISTS notify_mode VARCHAR(20) NOT NULL DEFAULT 'silent'; diff --git a/backend/migrations/069_add_group_messages_dispatch.sql b/backend/migrations/069_add_group_messages_dispatch.sql new file mode 100644 index 00000000..7b9d5f5d --- /dev/null +++ b/backend/migrations/069_add_group_messages_dispatch.sql @@ -0,0 +1,2 @@ +ALTER TABLE groups ADD COLUMN 
allow_messages_dispatch BOOLEAN NOT NULL DEFAULT false; +ALTER TABLE groups ADD COLUMN default_mapped_model VARCHAR(100) NOT NULL DEFAULT ''; diff --git a/backend/migrations/070_add_scheduled_test_auto_recover.sql b/backend/migrations/070_add_scheduled_test_auto_recover.sql new file mode 100644 index 00000000..5f0c6789 --- /dev/null +++ b/backend/migrations/070_add_scheduled_test_auto_recover.sql @@ -0,0 +1,4 @@ +-- 070: Add auto_recover column to scheduled_test_plans +-- When enabled, automatically recovers account from error/rate-limited state on successful test + +ALTER TABLE scheduled_test_plans ADD COLUMN IF NOT EXISTS auto_recover BOOLEAN NOT NULL DEFAULT false; diff --git a/backend/migrations/070_add_usage_log_service_tier.sql b/backend/migrations/070_add_usage_log_service_tier.sql new file mode 100644 index 00000000..085ec0d6 --- /dev/null +++ b/backend/migrations/070_add_usage_log_service_tier.sql @@ -0,0 +1,5 @@ +ALTER TABLE usage_logs + ADD COLUMN IF NOT EXISTS service_tier VARCHAR(16); + +CREATE INDEX IF NOT EXISTS idx_usage_logs_service_tier_created_at + ON usage_logs (service_tier, created_at); diff --git a/backend/migrations/071_add_gemini25_flash_image_to_model_mapping.sql b/backend/migrations/071_add_gemini25_flash_image_to_model_mapping.sql new file mode 100644 index 00000000..f3cb3d37 --- /dev/null +++ b/backend/migrations/071_add_gemini25_flash_image_to_model_mapping.sql @@ -0,0 +1,51 @@ +-- Add gemini-2.5-flash-image aliases to Antigravity model_mapping +-- +-- Background: +-- Gemini native image generation now relies on gemini-2.5-flash-image, and +-- existing Antigravity accounts with persisted model_mapping need this alias in +-- order to participate in mixed scheduling from gemini groups. +-- +-- Strategy: +-- Overwrite the stored model_mapping so it matches DefaultAntigravityModelMapping +-- in constants.go, including legacy gemini-3-pro-image aliases. 
+ +UPDATE accounts +SET credentials = jsonb_set( + credentials, + '{model_mapping}', + '{ + "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-6": "claude-opus-4-6-thinking", + "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", + "claude-sonnet-4-6": "claude-sonnet-4-6", + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + "claude-haiku-4-5": "claude-sonnet-4-5", + "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-image": "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview": "gemini-2.5-flash-image", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-3-flash": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-preview": "gemini-3-pro-high", + "gemini-3.1-pro-high": "gemini-3.1-pro-high", + "gemini-3.1-pro-low": "gemini-3.1-pro-low", + "gemini-3.1-pro-preview": "gemini-3.1-pro-high", + "gemini-3.1-flash-image": "gemini-3.1-flash-image", + "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", + "gemini-3-pro-image": "gemini-3.1-flash-image", + "gemini-3-pro-image-preview": "gemini-3.1-flash-image", + "gpt-oss-120b-medium": "gpt-oss-120b-medium", + "tab_flash_lite_preview": "tab_flash_lite_preview" + }'::jsonb +) +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping' IS NOT NULL; diff --git a/backend/migrations/071_add_usage_billing_dedup.sql b/backend/migrations/071_add_usage_billing_dedup.sql new file mode 100644 index 00000000..acc28459 --- /dev/null +++ b/backend/migrations/071_add_usage_billing_dedup.sql @@ -0,0 +1,13 @@ +-- 窄表账务幂等键:将“是否已扣费”从 
usage_logs 解耦出来 +-- 幂等执行:可重复运行 + +CREATE TABLE IF NOT EXISTS usage_billing_dedup ( + id BIGSERIAL PRIMARY KEY, + request_id VARCHAR(255) NOT NULL, + api_key_id BIGINT NOT NULL, + request_fingerprint VARCHAR(64) NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_billing_dedup_request_api_key + ON usage_billing_dedup (request_id, api_key_id); diff --git a/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql b/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql new file mode 100644 index 00000000..965a3412 --- /dev/null +++ b/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql @@ -0,0 +1,7 @@ +-- usage_billing_dedup 是按时间追加写入的幂等窄表。 +-- 使用 BRIN 支撑按 created_at 的批量保留期清理,尽量降低写放大。 +-- 使用 CONCURRENTLY 避免在热表上长时间阻塞写入。 + +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_billing_dedup_created_at_brin + ON usage_billing_dedup + USING BRIN (created_at); diff --git a/backend/migrations/073_add_usage_billing_dedup_archive.sql b/backend/migrations/073_add_usage_billing_dedup_archive.sql new file mode 100644 index 00000000..d156d4eb --- /dev/null +++ b/backend/migrations/073_add_usage_billing_dedup_archive.sql @@ -0,0 +1,10 @@ +-- 冷归档旧账务幂等键,缩小热表索引与清理范围,同时不丢失长期去重能力。 + +CREATE TABLE IF NOT EXISTS usage_billing_dedup_archive ( + request_id VARCHAR(255) NOT NULL, + api_key_id BIGINT NOT NULL, + request_fingerprint VARCHAR(64) NOT NULL, + created_at TIMESTAMPTZ NOT NULL, + archived_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (request_id, api_key_id) +); diff --git a/backend/migrations/074_add_usage_log_endpoints.sql b/backend/migrations/074_add_usage_log_endpoints.sql new file mode 100644 index 00000000..2a34e7c3 --- /dev/null +++ b/backend/migrations/074_add_usage_log_endpoints.sql @@ -0,0 +1,5 @@ +-- Add endpoint tracking fields to usage_logs. +-- inbound_endpoint: client-facing API route (e.g. 
/v1/chat/completions, /v1/messages, /v1/responses) +-- upstream_endpoint: normalized upstream route (e.g. /v1/responses) +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(128); +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(128); diff --git a/backend/resources/model-pricing/model_prices_and_context_window.json b/backend/resources/model-pricing/model_prices_and_context_window.json index 650e128e..72860bf9 100644 --- a/backend/resources/model-pricing/model_prices_and_context_window.json +++ b/backend/resources/model-pricing/model_prices_and_context_window.json @@ -5140,6 +5140,39 @@ "supports_vision": true, "supports_web_search": true }, + "gpt-5.4": { + "cache_read_input_token_cost": 2.5e-07, + "input_cost_per_token": 2.5e-06, + "litellm_provider": "openai", + "max_input_tokens": 1050000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 1.5e-05, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text", + "image" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_service_tier": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true + }, "gpt-5.3-codex": { "cache_read_input_token_cost": 1.75e-07, "cache_read_input_token_cost_priority": 3.5e-07, diff --git a/build_image.sh b/build_image.sh deleted file mode 100755 index f716e984..00000000 --- a/build_image.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env bash -# 本地构建镜像的快速脚本,避免在命令行反复输入构建参数。 - -set -euo pipefail - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -docker build -t sub2api:latest \ - --build-arg 
GOPROXY=https://goproxy.cn,direct \ - --build-arg GOSUMDB=sum.golang.google.cn \ - -f "${SCRIPT_DIR}/Dockerfile" \ - "${SCRIPT_DIR}" diff --git a/deploy/.env.example b/deploy/.env.example index 9f2ff13e..e1eb8256 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -112,7 +112,7 @@ POSTGRES_DB=sub2api DATABASE_PORT=5432 # ----------------------------------------------------------------------------- -# PostgreSQL 服务端参数(可选;主要用于 deploy/docker-compose-aicodex.yml) +# PostgreSQL 服务端参数(可选) # ----------------------------------------------------------------------------- # POSTGRES_MAX_CONNECTIONS:PostgreSQL 服务端允许的最大连接数。 # 必须 >=(所有 Sub2API 实例的 DATABASE_MAX_OPEN_CONNS 之和)+ 预留余量(例如 20%)。 @@ -163,7 +163,7 @@ REDIS_PORT=6379 # Leave empty for no password (default for local development) REDIS_PASSWORD= REDIS_DB=0 -# Redis 服务端最大客户端连接数(可选;主要用于 deploy/docker-compose-aicodex.yml) +# Redis 服务端最大客户端连接数(可选) REDIS_MAXCLIENTS=50000 # Redis 连接池大小(默认 1024) REDIS_POOL_SIZE=4096 diff --git a/deploy/Dockerfile b/deploy/Dockerfile index b3320300..0f4f1de9 100644 --- a/deploy/Dockerfile +++ b/deploy/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.25.5-alpine +ARG GOLANG_IMAGE=golang:1.26.1-alpine ARG ALPINE_IMAGE=alpine:3.20 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn @@ -105,7 +105,7 @@ EXPOSE 8080 # Health check HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ - CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1 + CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1 # Run the application ENTRYPOINT ["/app/sub2api"] diff --git a/deploy/build_image.sh b/deploy/build_image.sh old mode 100755 new mode 100644 diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index faa85854..2058ced1 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ 
-134,6 +134,12 @@ security: # Allow skipping TLS verification for proxy probe (debug only) # 允许代理探测时跳过 TLS 证书验证(仅用于调试) insecure_skip_verify: false + proxy_fallback: + # Allow auxiliary services (update check, pricing data) to fallback to direct + # connection when proxy initialization fails. Does NOT affect AI gateway connections. + # 辅助服务(更新检查、定价数据拉取)代理初始化失败时是否允许回退直连。 + # 不影响 AI 账号网关连接。默认 false:fail-fast 防止 IP 泄露。 + allow_direct_on_error: false # ============================================================================= # Gateway Configuration @@ -203,8 +209,9 @@ gateway: openai_ws: # 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。 mode_router_v2_enabled: false - # ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效) - ingress_mode_default: shared + # ingress 默认模式:off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效) + # 兼容旧值:shared/dedicated 会按 ctx_pool 处理。 + ingress_mode_default: ctx_pool # 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由 enabled: true # 按账号类型细分开关 diff --git a/deploy/docker-compose-test.yml b/deploy/docker-compose-test.yml deleted file mode 100644 index 4c7ec144..00000000 --- a/deploy/docker-compose-test.yml +++ /dev/null @@ -1,212 +0,0 @@ -# ============================================================================= -# Sub2API Docker Compose Test Configuration (Local Build) -# ============================================================================= -# Quick Start: -# 1. Copy .env.example to .env and configure -# 2. docker-compose -f docker-compose-test.yml up -d --build -# 3. Check logs: docker-compose -f docker-compose-test.yml logs -f sub2api -# 4. Access: http://localhost:8080 -# -# This configuration builds the image from source (Dockerfile in project root). -# All configuration is done via environment variables. -# No Setup Wizard needed - the system auto-initializes on first run. 
-# ============================================================================= - -services: - # =========================================================================== - # Sub2API Application - # =========================================================================== - sub2api: - image: sub2api:latest - build: - context: .. - dockerfile: Dockerfile - container_name: sub2api - restart: unless-stopped - ulimits: - nofile: - soft: 100000 - hard: 100000 - ports: - - "${BIND_HOST:-0.0.0.0}:${SERVER_PORT:-8080}:8080" - volumes: - # Data persistence (config.yaml will be auto-generated here) - - sub2api_data:/app/data - # Mount custom config.yaml (optional, overrides auto-generated config) - # - ./config.yaml:/app/data/config.yaml:ro - environment: - # ======================================================================= - # Auto Setup (REQUIRED for Docker deployment) - # ======================================================================= - - AUTO_SETUP=true - - # ======================================================================= - # Server Configuration - # ======================================================================= - - SERVER_HOST=0.0.0.0 - - SERVER_PORT=8080 - - SERVER_MODE=${SERVER_MODE:-release} - - RUN_MODE=${RUN_MODE:-standard} - - # ======================================================================= - # Database Configuration (PostgreSQL) - # ======================================================================= - - DATABASE_HOST=postgres - - DATABASE_PORT=5432 - - DATABASE_USER=${POSTGRES_USER:-sub2api} - - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - - DATABASE_DBNAME=${POSTGRES_DB:-sub2api} - - DATABASE_SSLMODE=disable - - DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50} - - DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10} - - DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30} - - DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5} - - 
# ======================================================================= - # Redis Configuration - # ======================================================================= - - REDIS_HOST=redis - - REDIS_PORT=6379 - - REDIS_PASSWORD=${REDIS_PASSWORD:-} - - REDIS_DB=${REDIS_DB:-0} - - REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024} - - REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10} - - # ======================================================================= - # Admin Account (auto-created on first run) - # ======================================================================= - - ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local} - - ADMIN_PASSWORD=${ADMIN_PASSWORD:-} - - # ======================================================================= - # JWT Configuration - # ======================================================================= - # Leave empty to auto-generate (recommended) - - JWT_SECRET=${JWT_SECRET:-} - - JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24} - - # ======================================================================= - # Timezone Configuration - # This affects ALL time operations in the application: - # - Database timestamps - # - Usage statistics "today" boundary - # - Subscription expiry times - # - Log timestamps - # Common values: Asia/Shanghai, America/New_York, Europe/London, UTC - # ======================================================================= - - TZ=${TZ:-Asia/Shanghai} - - # ======================================================================= - # Gemini OAuth Configuration (for Gemini accounts) - # ======================================================================= - - GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-} - - GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-} - - GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-} - - GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-} - - # Built-in OAuth client secrets (optional) - # SECURITY: This repo does not embed third-party client_secret. 
- - GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-} - - ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-} - - # ======================================================================= - # Security Configuration (URL Allowlist) - # ======================================================================= - # Allow private IP addresses for CRS sync (for internal deployments) - - SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=${SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS:-true} - depends_on: - postgres: - condition: service_healthy - redis: - condition: service_healthy - networks: - - sub2api-network - healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/health"] - interval: 30s - timeout: 10s - retries: 3 - start_period: 30s - - # =========================================================================== - # PostgreSQL Database - # =========================================================================== - postgres: - image: postgres:18-alpine - container_name: sub2api-postgres - restart: unless-stopped - ulimits: - nofile: - soft: 100000 - hard: 100000 - volumes: - - postgres_data:/var/lib/postgresql/data - environment: - # postgres:18-alpine 默认 PGDATA=/var/lib/postgresql/18/docker(位于镜像声明的匿名卷 /var/lib/postgresql 内)。 - # 若不显式设置 PGDATA,则即使挂载了 postgres_data 到 /var/lib/postgresql/data,数据也不会落盘到该命名卷, - # docker compose down/up 后会触发 initdb 重新初始化,导致用户/密码等数据丢失。 - - PGDATA=/var/lib/postgresql/data - - POSTGRES_USER=${POSTGRES_USER:-sub2api} - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - - POSTGRES_DB=${POSTGRES_DB:-sub2api} - - TZ=${TZ:-Asia/Shanghai} - networks: - - sub2api-network - healthcheck: - test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"] - interval: 10s - timeout: 5s - retries: 5 - start_period: 10s - # 注意:不暴露端口到宿主机,应用通过内部网络连接 - # 如需调试,可临时添加:ports: ["127.0.0.1:5433:5432"] - - # 
=========================================================================== - # Redis Cache - # =========================================================================== - redis: - image: redis:8-alpine - container_name: sub2api-redis - restart: unless-stopped - ulimits: - nofile: - soft: 100000 - hard: 100000 - volumes: - - redis_data:/data - command: > - redis-server - --save 60 1 - --appendonly yes - --appendfsync everysec - ${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}} - environment: - - TZ=${TZ:-Asia/Shanghai} - # REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag) - - REDISCLI_AUTH=${REDIS_PASSWORD:-} - networks: - - sub2api-network - healthcheck: - test: ["CMD", "redis-cli", "ping"] - interval: 10s - timeout: 5s - retries: 5 - start_period: 5s - -# ============================================================================= -# Volumes -# ============================================================================= -volumes: - sub2api_data: - driver: local - postgres_data: - driver: local - redis_data: - driver: local - -# ============================================================================= -# Networks -# ============================================================================= -networks: - sub2api-network: - driver: bridge diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml new file mode 100644 index 00000000..7793e424 --- /dev/null +++ b/deploy/docker-compose.dev.yml @@ -0,0 +1,105 @@ +# ============================================================================= +# Sub2API Docker Compose - Local Development Build +# ============================================================================= +# Build from local source code for testing changes. +# +# Usage: +# cd deploy +# docker compose -f docker-compose.dev.yml up --build +# ============================================================================= + +services: + sub2api: + build: + context: .. 
+ dockerfile: Dockerfile + container_name: sub2api-dev + restart: unless-stopped + ports: + - "${BIND_HOST:-127.0.0.1}:${SERVER_PORT:-8080}:8080" + volumes: + - ./data:/app/data + environment: + - AUTO_SETUP=true + - SERVER_HOST=0.0.0.0 + - SERVER_PORT=8080 + - SERVER_MODE=debug + - RUN_MODE=${RUN_MODE:-standard} + - DATABASE_HOST=postgres + - DATABASE_PORT=5432 + - DATABASE_USER=${POSTGRES_USER:-sub2api} + - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} + - DATABASE_DBNAME=${POSTGRES_DB:-sub2api} + - DATABASE_SSLMODE=disable + - REDIS_HOST=redis + - REDIS_PORT=6379 + - REDIS_PASSWORD=${REDIS_PASSWORD:-} + - REDIS_DB=${REDIS_DB:-0} + - ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local} + - ADMIN_PASSWORD=${ADMIN_PASSWORD:-} + - JWT_SECRET=${JWT_SECRET:-} + - TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-} + - TZ=${TZ:-Asia/Shanghai} + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + networks: + - sub2api-network + healthcheck: + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + + postgres: + image: postgres:18-alpine + container_name: sub2api-postgres-dev + restart: unless-stopped + volumes: + - ./postgres_data:/var/lib/postgresql/data + environment: + - POSTGRES_USER=${POSTGRES_USER:-sub2api} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} + - POSTGRES_DB=${POSTGRES_DB:-sub2api} + - PGDATA=/var/lib/postgresql/data + - TZ=${TZ:-Asia/Shanghai} + networks: + - sub2api-network + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 10s + + redis: + image: redis:8-alpine + container_name: sub2api-redis-dev + restart: unless-stopped + volumes: + - ./redis_data:/data + command: > + sh -c ' + redis-server + --save 60 1 + --appendonly yes + --appendfsync everysec + 
${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}' + environment: + - TZ=${TZ:-Asia/Shanghai} + - REDISCLI_AUTH=${REDIS_PASSWORD:-} + networks: + - sub2api-network + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 5s + +networks: + sub2api-network: + driver: bridge diff --git a/deploy/docker-compose.local.yml b/deploy/docker-compose.local.yml index 0ef397df..d404ac0b 100644 --- a/deploy/docker-compose.local.yml +++ b/deploy/docker-compose.local.yml @@ -154,7 +154,7 @@ services: networks: - sub2api-network healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] interval: 30s timeout: 10s retries: 3 diff --git a/deploy/docker-compose.override.yml.example b/deploy/docker-compose.override.yml.example deleted file mode 100644 index 7157f212..00000000 --- a/deploy/docker-compose.override.yml.example +++ /dev/null @@ -1,150 +0,0 @@ -# ============================================================================= -# Docker Compose Override Configuration Example -# ============================================================================= -# This file provides examples for customizing the Docker Compose setup. -# Copy this file to docker-compose.override.yml and modify as needed. -# -# Usage: -# cp docker-compose.override.yml.example docker-compose.override.yml -# # Edit docker-compose.override.yml with your settings -# docker-compose up -d -# -# IMPORTANT: docker-compose.override.yml is gitignored and will not be committed. 
-# ============================================================================= - -# ============================================================================= -# Scenario 1: Use External Database and Redis (Recommended for Production) -# ============================================================================= -# Use this when you have PostgreSQL and Redis running on the host machine -# or on separate servers. -# -# Prerequisites: -# - PostgreSQL running on host (accessible via host.docker.internal) -# - Redis running on host (accessible via host.docker.internal) -# - Update DATABASE_PORT and REDIS_PORT in .env file if using non-standard ports -# -# Security Notes: -# - Ensure PostgreSQL pg_hba.conf allows connections from Docker network -# - Use strong passwords for database and Redis -# - Consider using SSL/TLS for database connections in production -# ============================================================================= - -services: - sub2api: - # Remove dependencies on containerized postgres/redis - depends_on: [] - - # Enable access to host machine services - extra_hosts: - - "host.docker.internal:host-gateway" - - # Override database and Redis connection settings - environment: - # PostgreSQL Configuration - DATABASE_HOST: host.docker.internal - DATABASE_PORT: "5678" # Change to your PostgreSQL port - # DATABASE_USER: postgres # Uncomment to override - # DATABASE_PASSWORD: your_password # Uncomment to override - # DATABASE_DBNAME: sub2api # Uncomment to override - - # Redis Configuration - REDIS_HOST: host.docker.internal - REDIS_PORT: "6379" # Change to your Redis port - # REDIS_PASSWORD: your_redis_password # Uncomment if Redis requires auth - # REDIS_DB: 0 # Uncomment to override - - # Disable containerized PostgreSQL - postgres: - deploy: - replicas: 0 - scale: 0 - - # Disable containerized Redis - redis: - deploy: - replicas: 0 - scale: 0 - -# ============================================================================= -# Scenario 2: 
Development with Local Services (Alternative) -# ============================================================================= -# Uncomment this section if you want to use the containerized postgres/redis -# but expose their ports for local development tools. -# -# Usage: Comment out Scenario 1 above and uncomment this section. -# ============================================================================= - -# services: -# sub2api: -# # Keep default dependencies -# pass -# -# postgres: -# ports: -# - "127.0.0.1:5432:5432" # Expose PostgreSQL on localhost -# -# redis: -# ports: -# - "127.0.0.1:6379:6379" # Expose Redis on localhost - -# ============================================================================= -# Scenario 3: Custom Network Configuration -# ============================================================================= -# Uncomment if you need to connect to an existing Docker network -# ============================================================================= - -# networks: -# default: -# external: true -# name: your-existing-network - -# ============================================================================= -# Scenario 4: Resource Limits (Production) -# ============================================================================= -# Uncomment to set resource limits for the sub2api container -# ============================================================================= - -# services: -# sub2api: -# deploy: -# resources: -# limits: -# cpus: '2.0' -# memory: 2G -# reservations: -# cpus: '1.0' -# memory: 1G - -# ============================================================================= -# Scenario 5: Custom Volumes -# ============================================================================= -# Uncomment to mount additional volumes (e.g., for logs, backups) -# ============================================================================= - -# services: -# sub2api: -# volumes: -# - ./logs:/app/logs -# - ./backups:/app/backups - -# 
============================================================================= -# Scenario 6: 启用宿主机 datamanagementd(数据管理) -# ============================================================================= -# 说明: -# - datamanagementd 运行在宿主机(systemd 或手动) -# - 主进程固定探测 /tmp/sub2api-datamanagement.sock -# - 需要把宿主机 socket 挂载到容器内同路径 -# -# services: -# sub2api: -# volumes: -# - /tmp/sub2api-datamanagement.sock:/tmp/sub2api-datamanagement.sock - -# ============================================================================= -# Additional Notes -# ============================================================================= -# - This file overrides settings in docker-compose.yml -# - Environment variables in .env file take precedence -# - For more information, see: https://docs.docker.com/compose/extends/ -# - Check the main README.md for detailed configuration instructions -# ============================================================================= diff --git a/deploy/docker-compose.standalone.yml b/deploy/docker-compose.standalone.yml index 7676fb97..df0ccfcc 100644 --- a/deploy/docker-compose.standalone.yml +++ b/deploy/docker-compose.standalone.yml @@ -94,7 +94,7 @@ services: - GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-} - ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-} healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] interval: 30s timeout: 10s retries: 3 diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 5694fbe5..99b05446 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -127,7 +127,7 @@ services: - sub2api-network - 1panel-network healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] interval: 30s timeout: 10s retries: 3 diff --git 
a/deploy/docker-deploy.sh b/deploy/docker-deploy.sh index 1e4ce81f..a07f4f41 100644 --- a/deploy/docker-deploy.sh +++ b/deploy/docker-deploy.sh @@ -8,7 +8,7 @@ # - Creates necessary data directories # # After running this script, you can start services with: -# docker-compose -f docker-compose.local.yml up -d +# docker-compose up -d # ============================================================================= set -e @@ -65,7 +65,7 @@ main() { fi # Check if deployment already exists - if [ -f "docker-compose.local.yml" ] && [ -f ".env" ]; then + if [ -f "docker-compose.yml" ] && [ -f ".env" ]; then print_warning "Deployment files already exist in current directory." read -p "Overwrite existing files? (y/N): " -r echo @@ -75,17 +75,17 @@ main() { fi fi - # Download docker-compose.local.yml - print_info "Downloading docker-compose.local.yml..." + # Download docker-compose.local.yml and save as docker-compose.yml + print_info "Downloading docker-compose.yml..." if command_exists curl; then - curl -sSL "${GITHUB_RAW_URL}/docker-compose.local.yml" -o docker-compose.local.yml + curl -sSL "${GITHUB_RAW_URL}/docker-compose.local.yml" -o docker-compose.yml elif command_exists wget; then - wget -q "${GITHUB_RAW_URL}/docker-compose.local.yml" -O docker-compose.local.yml + wget -q "${GITHUB_RAW_URL}/docker-compose.local.yml" -O docker-compose.yml else print_error "Neither curl nor wget is installed. Please install one of them." exit 1 fi - print_success "Downloaded docker-compose.local.yml" + print_success "Downloaded docker-compose.yml" # Download .env.example print_info "Downloading .env.example..." @@ -144,7 +144,7 @@ main() { print_warning "Please keep them secure and do not share publicly!" 
echo "" echo "Directory structure:" - echo " docker-compose.local.yml - Docker Compose configuration" + echo " docker-compose.yml - Docker Compose configuration" echo " .env - Environment variables (generated secrets)" echo " .env.example - Example template (for reference)" echo " data/ - Application data (will be created on first run)" @@ -154,10 +154,10 @@ main() { echo "Next steps:" echo " 1. (Optional) Edit .env to customize configuration" echo " 2. Start services:" - echo " docker-compose -f docker-compose.local.yml up -d" + echo " docker-compose up -d" echo "" echo " 3. View logs:" - echo " docker-compose -f docker-compose.local.yml logs -f sub2api" + echo " docker-compose logs -f sub2api" echo "" echo " 4. Access Web UI:" echo " http://localhost:8080" diff --git a/deploy/flow.md b/deploy/flow.md deleted file mode 100644 index 0904c72f..00000000 --- a/deploy/flow.md +++ /dev/null @@ -1,222 +0,0 @@ -```mermaid -flowchart TD - %% Master dispatch - A[HTTP Request] --> B{Route} - B -->|v1 messages| GA0 - B -->|openai v1 responses| OA0 - B -->|v1beta models model action| GM0 - B -->|v1 messages count tokens| GT0 - B -->|v1beta models list or get| GL0 - - %% ========================= - %% FLOW A: Claude Gateway - %% ========================= - subgraph FLOW_A["v1 messages Claude Gateway"] - GA0[Auth middleware] --> GA1[Read body] - GA1 -->|empty| GA1E[400 invalid_request_error] - GA1 --> GA2[ParseGatewayRequest] - GA2 -->|parse error| GA2E[400 invalid_request_error] - GA2 --> GA3{model present} - GA3 -->|no| GA3E[400 invalid_request_error] - GA3 --> GA4[streamStarted false] - GA4 --> GA5[IncrementWaitCount user] - GA5 -->|queue full| GA5E[429 rate_limit_error] - GA5 --> GA6[AcquireUserSlotWithWait] - GA6 -->|timeout or fail| GA6E[429 rate_limit_error] - GA6 --> GA7[BillingEligibility check post wait] - GA7 -->|fail| GA7E[403 billing_error] - GA7 --> GA8[Generate sessionHash] - GA8 --> GA9[Resolve platform] - GA9 --> GA10{platform gemini} - GA10 -->|yes| 
GA10Y[sessionKey gemini hash] - GA10 -->|no| GA10N[sessionKey hash] - GA10Y --> GA11 - GA10N --> GA11 - - GA11[SelectAccountWithLoadAwareness] -->|err and no failed| GA11E1[503 no available accounts] - GA11 -->|err and failed| GA11E2[map failover error] - GA11 --> GA12[Warmup intercept] - GA12 -->|yes| GA12Y[return mock and release if held] - GA12 -->|no| GA13[Acquire account slot or wait] - GA13 -->|wait queue full| GA13E1[429 rate_limit_error] - GA13 -->|wait timeout| GA13E2[429 concurrency limit] - GA13 --> GA14[BindStickySession if waited] - GA14 --> GA15{account platform antigravity} - GA15 -->|yes| GA15Y[ForwardGemini antigravity] - GA15 -->|no| GA15N[Forward Claude] - GA15Y --> GA16[Release account slot and dec account wait] - GA15N --> GA16 - GA16 --> GA17{UpstreamFailoverError} - GA17 -->|yes| GA18[mark failedAccountIDs and map error if exceed] - GA18 -->|loop| GA11 - GA17 -->|no| GA19[success async RecordUsage and return] - GA19 --> GA20[defer release user slot and dec wait count] - end - - %% ========================= - %% FLOW B: OpenAI - %% ========================= - subgraph FLOW_B["openai v1 responses"] - OA0[Auth middleware] --> OA1[Read body] - OA1 -->|empty| OA1E[400 invalid_request_error] - OA1 --> OA2[json Unmarshal body] - OA2 -->|parse error| OA2E[400 invalid_request_error] - OA2 --> OA3{model present} - OA3 -->|no| OA3E[400 invalid_request_error] - OA3 --> OA4{User Agent Codex CLI} - OA4 -->|no| OA4N[set default instructions] - OA4 -->|yes| OA4Y[no change] - OA4N --> OA5 - OA4Y --> OA5 - OA5[streamStarted false] --> OA6[IncrementWaitCount user] - OA6 -->|queue full| OA6E[429 rate_limit_error] - OA6 --> OA7[AcquireUserSlotWithWait] - OA7 -->|timeout or fail| OA7E[429 rate_limit_error] - OA7 --> OA8[BillingEligibility check post wait] - OA8 -->|fail| OA8E[403 billing_error] - OA8 --> OA9[sessionHash sha256 session_id] - OA9 --> OA10[SelectAccountWithLoadAwareness] - OA10 -->|err and no failed| OA10E1[503 no available accounts] - OA10 -->|err 
and failed| OA10E2[map failover error] - OA10 --> OA11[Acquire account slot or wait] - OA11 -->|wait queue full| OA11E1[429 rate_limit_error] - OA11 -->|wait timeout| OA11E2[429 concurrency limit] - OA11 --> OA12[BindStickySession openai hash if waited] - OA12 --> OA13[Forward OpenAI upstream] - OA13 --> OA14[Release account slot and dec account wait] - OA14 --> OA15{UpstreamFailoverError} - OA15 -->|yes| OA16[mark failedAccountIDs and map error if exceed] - OA16 -->|loop| OA10 - OA15 -->|no| OA17[success async RecordUsage and return] - OA17 --> OA18[defer release user slot and dec wait count] - end - - %% ========================= - %% FLOW C: Gemini Native - %% ========================= - subgraph FLOW_C["v1beta models model action Gemini Native"] - GM0[Auth middleware] --> GM1[Validate platform] - GM1 -->|invalid| GM1E[400 googleError] - GM1 --> GM2[Parse path modelName action] - GM2 -->|invalid| GM2E[400 googleError] - GM2 --> GM3{action supported} - GM3 -->|no| GM3E[404 googleError] - GM3 --> GM4[Read body] - GM4 -->|empty| GM4E[400 googleError] - GM4 --> GM5[streamStarted false] - GM5 --> GM6[IncrementWaitCount user] - GM6 -->|queue full| GM6E[429 googleError] - GM6 --> GM7[AcquireUserSlotWithWait] - GM7 -->|timeout or fail| GM7E[429 googleError] - GM7 --> GM8[BillingEligibility check post wait] - GM8 -->|fail| GM8E[403 googleError] - GM8 --> GM9[Generate sessionHash] - GM9 --> GM10[sessionKey gemini hash] - GM10 --> GM11[SelectAccountWithLoadAwareness] - GM11 -->|err and no failed| GM11E1[503 googleError] - GM11 -->|err and failed| GM11E2[mapGeminiUpstreamError] - GM11 --> GM12[Acquire account slot or wait] - GM12 -->|wait queue full| GM12E1[429 googleError] - GM12 -->|wait timeout| GM12E2[429 googleError] - GM12 --> GM13[BindStickySession if waited] - GM13 --> GM14{account platform antigravity} - GM14 -->|yes| GM14Y[ForwardGemini antigravity] - GM14 -->|no| GM14N[ForwardNative] - GM14Y --> GM15[Release account slot and dec account wait] - GM14N --> GM15 - 
GM15 --> GM16{UpstreamFailoverError} - GM16 -->|yes| GM17[mark failedAccountIDs and map error if exceed] - GM17 -->|loop| GM11 - GM16 -->|no| GM18[success async RecordUsage and return] - GM18 --> GM19[defer release user slot and dec wait count] - end - - %% ========================= - %% FLOW D: CountTokens - %% ========================= - subgraph FLOW_D["v1 messages count tokens"] - GT0[Auth middleware] --> GT1[Read body] - GT1 -->|empty| GT1E[400 invalid_request_error] - GT1 --> GT2[ParseGatewayRequest] - GT2 -->|parse error| GT2E[400 invalid_request_error] - GT2 --> GT3{model present} - GT3 -->|no| GT3E[400 invalid_request_error] - GT3 --> GT4[BillingEligibility check] - GT4 -->|fail| GT4E[403 billing_error] - GT4 --> GT5[ForwardCountTokens] - end - - %% ========================= - %% FLOW E: Gemini Models List Get - %% ========================= - subgraph FLOW_E["v1beta models list or get"] - GL0[Auth middleware] --> GL1[Validate platform] - GL1 -->|invalid| GL1E[400 googleError] - GL1 --> GL2{force platform antigravity} - GL2 -->|yes| GL2Y[return static fallback models] - GL2 -->|no| GL3[SelectAccountForAIStudioEndpoints] - GL3 -->|no gemini and has antigravity| GL3Y[return fallback models] - GL3 -->|no accounts| GL3E[503 googleError] - GL3 --> GL4[ForwardAIStudioGET] - GL4 -->|error| GL4E[502 googleError] - GL4 --> GL5[Passthrough response or fallback] - end - - %% ========================= - %% SHARED: Account Selection - %% ========================= - subgraph SELECT["SelectAccountWithLoadAwareness detail"] - S0[Start] --> S1{concurrencyService nil OR load batch disabled} - S1 -->|yes| S2[SelectAccountForModelWithExclusions legacy] - S2 --> S3[tryAcquireAccountSlot] - S3 -->|acquired| S3Y[SelectionResult Acquired true ReleaseFunc] - S3 -->|not acquired| S3N[WaitPlan FallbackTimeout MaxWaiting] - S1 -->|no| S4[Resolve platform] - S4 --> S5[List schedulable accounts] - S5 --> S6[Layer1 Sticky session] - S6 -->|hit and valid| S6A[tryAcquireAccountSlot] - S6A 
-->|acquired| S6AY[SelectionResult Acquired true] - S6A -->|not acquired and waitingCount < StickyMax| S6AN[WaitPlan StickyTimeout Max] - S6 --> S7[Layer2 Load aware] - S7 --> S7A[Load batch concurrency plus wait to loadRate] - S7A --> S7B[Sort priority load LRU OAuth prefer for Gemini] - S7B --> S7C[tryAcquireAccountSlot in order] - S7C -->|first success| S7CY[SelectionResult Acquired true] - S7C -->|none| S8[Layer3 Fallback wait] - S8 --> S8A[Sort priority LRU] - S8A --> S8B[WaitPlan FallbackTimeout Max] - end - - %% ========================= - %% SHARED: Wait Acquire - %% ========================= - subgraph WAIT["AcquireXSlotWithWait detail"] - W0[Try AcquireXSlot immediately] -->|acquired| W1[return ReleaseFunc] - W0 -->|not acquired| W2[Wait loop with timeout] - W2 --> W3[Backoff 100ms x1.5 jitter max2s] - W2 --> W4[If streaming and ping format send SSE ping] - W2 --> W5[Retry AcquireXSlot on timer] - W5 -->|acquired| W1 - W2 -->|timeout| W6[ConcurrencyError IsTimeout true] - end - - %% ========================= - %% SHARED: Account Wait Queue - %% ========================= - subgraph AQ["Account Wait Queue Redis Lua"] - Q1[IncrementAccountWaitCount] --> Q2{current >= max} - Q2 -->|yes| Q2Y[return false] - Q2 -->|no| Q3[INCR and if first set TTL] - Q3 --> Q4[return true] - Q5[DecrementAccountWaitCount] --> Q6[if current > 0 then DECR] - end - - %% ========================= - %% SHARED: Background cleanup - %% ========================= - subgraph CLEANUP["Slot Cleanup Worker"] - C0[StartSlotCleanupWorker interval] --> C1[List schedulable accounts] - C1 --> C2[CleanupExpiredAccountSlots per account] - C2 --> C3[Repeat every interval] - end -``` diff --git a/deploy/install-datamanagementd.sh b/deploy/install-datamanagementd.sh old mode 100755 new mode 100644 diff --git a/docs/ADMIN_PAYMENT_INTEGRATION_API.md b/docs/ADMIN_PAYMENT_INTEGRATION_API.md index 4cc21594..f674f86c 100644 --- a/docs/ADMIN_PAYMENT_INTEGRATION_API.md +++ 
b/docs/ADMIN_PAYMENT_INTEGRATION_API.md @@ -99,16 +99,17 @@ curl -X POST "${BASE}/api/v1/admin/users/123/balance" \ }' ``` -### 4) 购买页 URL Query 透传(iframe / 新窗口一致) -当 Sub2API 打开 `purchase_subscription_url` 时,会统一追加: +### 4) 购买页 / 自定义页面 URL Query 透传(iframe / 新窗口一致) +当 Sub2API 打开 `purchase_subscription_url` 或用户侧自定义页面 iframe URL 时,会统一追加: - `user_id` - `token` - `theme`(`light` / `dark`) +- `lang`(例如 `zh` / `en`,用于向嵌入页传递当前界面语言) - `ui_mode`(固定 `embedded`) 示例: ```text -https://pay.example.com/pay?user_id=123&token=&theme=light&ui_mode=embedded +https://pay.example.com/pay?user_id=123&token=&theme=light&lang=zh&ui_mode=embedded ``` ### 5) 失败处理建议 @@ -218,16 +219,17 @@ curl -X POST "${BASE}/api/v1/admin/users/123/balance" \ }' ``` -### 4) Purchase URL query forwarding (iframe and new tab) -When Sub2API opens `purchase_subscription_url`, it appends: +### 4) Purchase / Custom Page URL query forwarding (iframe and new tab) +When Sub2API opens `purchase_subscription_url` or a user-facing custom page iframe URL, it appends: - `user_id` - `token` - `theme` (`light` / `dark`) +- `lang` (for example `zh` / `en`, used to pass the current UI language to the embedded page) - `ui_mode` (fixed: `embedded`) Example: ```text -https://pay.example.com/pay?user_id=123&token=&theme=light&ui_mode=embedded +https://pay.example.com/pay?user_id=123&token=&theme=light&lang=zh&ui_mode=embedded ``` ### 5) Failure handling recommendations diff --git a/docs/backend-hotspot-api-performance-optimization-20260222.md b/docs/backend-hotspot-api-performance-optimization-20260222.md deleted file mode 100644 index 8290d49c..00000000 --- a/docs/backend-hotspot-api-performance-optimization-20260222.md +++ /dev/null @@ -1,249 +0,0 @@ -# 后端热点 API 性能优化审计与行动计划(2026-02-22) - -## 1. 目标与范围 - -本次文档用于沉淀后端热点 API 的性能审计结果,并给出可执行优化方案。 - -重点链路: -- `POST /v1/messages` -- `POST /v1/responses` -- `POST /sora/v1/chat/completions` -- `POST /v1beta/models/*modelAction`(Gemini 兼容链路) -- 相关调度、计费、Ops 记录链路 - -## 2. 
审计方式与结论边界 - -- 审计方式:静态代码审阅(只读),未对生产环境做侵入变更。 -- 结论类型:以“高置信度可优化点”为主,均附 `file:line` 证据。 -- 未覆盖项:本轮未执行压测与火焰图采样,吞吐增益需在压测环境量化确认。 - -## 3. 优先级总览 - -| 优先级 | 数量 | 结论 | -|---|---:|---| -| P0(Critical) | 2 | 存在资源失控风险,建议立即修复 | -| P1(High) | 2 | 明确的热点 DB/Redis 放大路径,建议本迭代完成 | -| P2(Medium) | 4 | 可观收益优化项,建议并行排期 | - -## 4. 详细问题清单 - -### 4.1 P0-1:使用量记录为“每请求一个 goroutine”,高峰下可能无界堆积 - -证据位置: -- `backend/internal/handler/gateway_handler.go:435` -- `backend/internal/handler/gateway_handler.go:704` -- `backend/internal/handler/openai_gateway_handler.go:382` -- `backend/internal/handler/sora_gateway_handler.go:400` -- `backend/internal/handler/gemini_v1beta_handler.go:523` - -问题描述: -- 记录用量使用 `go func(...)` 直接异步提交,未设置全局并发上限与排队背压。 -- 当 DB/Redis 变慢时,goroutine 数会随请求持续累积。 - -性能影响: -- `goroutine` 激增导致调度开销上升与内存占用增加。 -- 与数据库连接池(默认 `max_open_conns=256`)竞争,放大尾延迟。 - -优化建议: -- 引入“有界队列 + 固定 worker 池”替代每请求 goroutine。 -- 队列满时采用明确策略:丢弃(采样告警)或降级为同步短路。 -- 为 `RecordUsage` 路径增加超时、重试上限与失败计数指标。 - -验收指标: -- 峰值 `goroutines` 稳定,无线性增长。 -- 用量记录成功率、丢弃率、队列长度可观测。 - ---- - -### 4.2 P0-2:Ops 错误日志队列携带原始请求体,存在内存放大风险 - -证据位置: -- 队列容量与 job 结构:`backend/internal/handler/ops_error_logger.go:38`、`backend/internal/handler/ops_error_logger.go:43` -- 入队逻辑:`backend/internal/handler/ops_error_logger.go:132` -- 请求体放入 context:`backend/internal/handler/ops_error_logger.go:261` -- 读取并入队:`backend/internal/handler/ops_error_logger.go:548`、`backend/internal/handler/ops_error_logger.go:563`、`backend/internal/handler/ops_error_logger.go:727`、`backend/internal/handler/ops_error_logger.go:737` -- 入库前才裁剪:`backend/internal/service/ops_service.go:332`、`backend/internal/service/ops_service.go:339` -- 请求体默认上限:`backend/internal/config/config.go:1082`、`backend/internal/config/config.go:1086` - -问题描述: -- 队列元素包含 `[]byte requestBody`,在请求体较大且错误风暴时会显著占用内存。 -- 当前裁剪发生在 worker 消费时,而不是入队前。 - -性能影响: -- 容易造成瞬时高内存与频繁 GC。 -- 极端情况下可能触发 OOM 或服务抖动。 - -优化建议: -- 入队前进行“脱敏 + 裁剪”,仅保留小尺寸结构化片段(建议 8KB~16KB)。 -- 队列存放轻量 DTO,避免持有大块 `[]byte`。 -- 按错误类型控制采样率,避免同类错误洪峰时日志放大。 - 
-验收指标: -- Ops 错误风暴期间 RSS/GC 次数显著下降。 -- 队列满时系统稳定且告警可见。 - ---- - -### 4.3 P1-1:窗口费用检查在缓存 miss 时逐账号做 DB 聚合 - -证据位置: -- 候选筛选多处调用:`backend/internal/service/gateway_service.go:1109`、`backend/internal/service/gateway_service.go:1137`、`backend/internal/service/gateway_service.go:1291`、`backend/internal/service/gateway_service.go:1354` -- miss 后单账号聚合:`backend/internal/service/gateway_service.go:1791` -- SQL 聚合实现:`backend/internal/repository/usage_log_repo.go:889` -- 窗口费用缓存 TTL:`backend/internal/repository/session_limit_cache.go:33` -- 已有批量读取接口但未利用:`backend/internal/repository/session_limit_cache.go:310` - -问题描述: -- 路由候选过滤阶段频繁调用窗口费用检查。 -- 缓存未命中时逐账号执行聚合查询,账号多时放大 DB 压力。 - -性能影响: -- 路由耗时上升,数据库聚合 QPS 增长。 -- 高并发下可能形成“缓存抖动 + 聚合风暴”。 - -优化建议: -- 先批量 `GetWindowCostBatch`,仅对 miss 账号执行批量 SQL 聚合。 -- 将聚合结果批量回写缓存,降低重复查询。 -- 评估窗口费用缓存 TTL 与刷新策略,减少抖动。 - -验收指标: -- 路由阶段 DB 查询次数下降。 -- `SelectAccountWithLoadAwareness` 平均耗时下降。 - ---- - -### 4.4 P1-2:记录用量时每次查询用户分组倍率,形成稳定 DB 热点 - -证据位置: -- `backend/internal/service/gateway_service.go:5316` -- `backend/internal/service/gateway_service.go:5531` -- `backend/internal/repository/user_group_rate_repo.go:45` - -问题描述: -- `RecordUsage` 与 `RecordUsageWithLongContext` 每次都执行 `GetByUserAndGroup`。 -- 热路径重复读数据库,且与 usage 写入、扣费路径竞争连接池。 - -性能影响: -- 增加 DB 往返与延迟,降低热点接口吞吐。 - -优化建议: -- 在鉴权或路由阶段预热倍率并挂载上下文复用。 -- 引入 L1/L2 缓存(短 TTL + singleflight),减少重复 SQL。 - -验收指标: -- `GetByUserAndGroup` 调用量明显下降。 -- 计费链路 p95 延迟下降。 - ---- - -### 4.5 P2-1:Claude 消息链路重复 JSON 解析 - -证据位置: -- 首次解析:`backend/internal/handler/gateway_handler.go:129` -- 二次解析入口:`backend/internal/handler/gateway_handler.go:146` -- 二次 `json.Unmarshal`:`backend/internal/handler/gateway_helper.go:22`、`backend/internal/handler/gateway_helper.go:26` - -问题描述: -- 同一请求先 `ParseGatewayRequest`,后 `SetClaudeCodeClientContext` 再做 `Unmarshal`。 - -性能影响: -- 增加 CPU 与内存分配,尤其对大 `messages` 请求更明显。 - -优化建议: -- 仅在 `User-Agent` 命中 Claude CLI 规则后再做 body 深解析。 -- 或直接复用首轮解析结果,避免重复反序列化。 - ---- - -### 4.6 P2-2:同一请求中粘性会话账号查询存在重复 Redis 读取 - -证据位置: 
-- Handler 预取:`backend/internal/handler/gateway_handler.go:242` -- Service 再取:`backend/internal/service/gateway_service.go:941`、`backend/internal/service/gateway_service.go:1129`、`backend/internal/service/gateway_service.go:1277` - -问题描述: -- 同一会话映射在同请求链路被多次读取。 - -性能影响: -- 增加 Redis RTT 与序列化开销,抬高路由延迟。 - -优化建议: -- 统一在 `SelectAccountWithLoadAwareness` 内读取并复用。 -- 或将上层已读到的 sticky account 显式透传给 service。 - ---- - -### 4.7 P2-3:并发等待路径存在重复抢槽 - -证据位置: -- 首次 TryAcquire:`backend/internal/handler/gateway_helper.go:182`、`backend/internal/handler/gateway_helper.go:202` -- wait 内再次立即 Acquire:`backend/internal/handler/gateway_helper.go:226`、`backend/internal/handler/gateway_helper.go:230`、`backend/internal/handler/gateway_helper.go:232` - -问题描述: -- 进入 wait 流程后会再做一次“立即抢槽”,与上层 TryAcquire 重复。 - -性能影响: -- 在高并发下增加 Redis 操作次数,放大锁竞争。 - -优化建议: -- wait 流程直接进入退避循环,避免重复立即抢槽。 - ---- - -### 4.8 P2-4:`/v1/models` 每次走仓储查询与对象装配,未复用快照/短缓存 - -证据位置: -- 入口调用:`backend/internal/handler/gateway_handler.go:767` -- 服务查询:`backend/internal/service/gateway_service.go:6152`、`backend/internal/service/gateway_service.go:6154` -- 对象装配:`backend/internal/repository/account_repo.go:1276`、`backend/internal/repository/account_repo.go:1290`、`backend/internal/repository/account_repo.go:1298` - -问题描述: -- 模型列表请求每次都落到账号查询与附加装配,缺少短时缓存。 - -性能影响: -- 高频请求下持续占用 DB 与 CPU。 - -优化建议: -- 以 `groupID + platform` 建 10s~30s 本地缓存。 -- 或复用调度快照 bucket 的可用账号结果做模型聚合。 - -## 5. 建议实施顺序 - -### 阶段 A(立即,P0) -- 将“用量记录每请求 goroutine”改为有界异步管道。 -- Ops 错误日志改为“入队前裁剪 + 轻量队列对象”。 - -### 阶段 B(短期,P1) -- 批量化窗口费用检查(缓存 + SQL 双批量)。 -- 用户分组倍率加缓存/上下文复用。 - -### 阶段 C(中期,P2) -- 消除重复 JSON 解析与重复 sticky 查询。 -- 优化并发等待重复抢槽逻辑。 -- `/v1/models` 接口加入短缓存或快照复用。 - -## 6. 压测与验证建议 - -建议在预发压测以下场景: -- 场景 1:常规成功流量(验证吞吐与延迟)。 -- 场景 2:上游慢响应(验证 goroutine 与队列稳定性)。 -- 场景 3:错误风暴(验证 Ops 队列与内存上限)。 -- 场景 4:多账号大分组路由(验证窗口费用批量化收益)。 - -建议监控指标: -- 进程:`goroutines`、RSS、GC 次数/停顿。 -- API:各热点接口 p50/p95/p99。 -- DB:QPS、慢查询、连接池等待。 -- Redis:命中率、RTT、命令量。 -- 业务:用量记录成功率/丢弃率、Ops 日志丢弃率。 - -## 7. 
待补充数据 - -- 生产真实错误率与错误体大小分布。 -- `window_cost_limit` 实际启用账号比例。 -- `/v1/models` 实际调用频次。 -- DB/Redis 当前容量余量与瓶颈点。 - ---- - -如需进入实现阶段,建议按“阶段 A → 阶段 B → 阶段 C”分 PR 推进,每个阶段都附压测报告与回滚方案。 diff --git a/docs/rename_local_migrations_20260202.sql b/docs/rename_local_migrations_20260202.sql deleted file mode 100644 index 911ed17d..00000000 --- a/docs/rename_local_migrations_20260202.sql +++ /dev/null @@ -1,34 +0,0 @@ --- 修正 schema_migrations 中“本地改名”的迁移文件名 --- 适用场景:你已执行过旧文件名的迁移,合并后仅改了自己这边的文件名 - -BEGIN; - -UPDATE schema_migrations -SET filename = '042b_add_ops_system_metrics_switch_count.sql' -WHERE filename = '042_add_ops_system_metrics_switch_count.sql' - AND NOT EXISTS ( - SELECT 1 FROM schema_migrations WHERE filename = '042b_add_ops_system_metrics_switch_count.sql' - ); - -UPDATE schema_migrations -SET filename = '043b_add_group_invalid_request_fallback.sql' -WHERE filename = '043_add_group_invalid_request_fallback.sql' - AND NOT EXISTS ( - SELECT 1 FROM schema_migrations WHERE filename = '043b_add_group_invalid_request_fallback.sql' - ); - -UPDATE schema_migrations -SET filename = '044b_add_group_mcp_xml_inject.sql' -WHERE filename = '044_add_group_mcp_xml_inject.sql' - AND NOT EXISTS ( - SELECT 1 FROM schema_migrations WHERE filename = '044b_add_group_mcp_xml_inject.sql' - ); - -UPDATE schema_migrations -SET filename = '046b_add_group_supported_model_scopes.sql' -WHERE filename = '046_add_group_supported_model_scopes.sql' - AND NOT EXISTS ( - SELECT 1 FROM schema_migrations WHERE filename = '046b_add_group_supported_model_scopes.sql' - ); - -COMMIT; diff --git a/frontend/src/App.vue b/frontend/src/App.vue index b831c9ff..4fc6a7c8 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -1,9 +1,10 @@ diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 30c3d739..64524d51 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ 
b/frontend/src/components/account/BulkEditAccountModal.vue @@ -164,27 +164,10 @@

- -
- -
+

{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} @@ -469,7 +452,7 @@ -

+
+
+
+ + +
+ +

{{ t('admin.accounts.loadFactorHint') }}

+
+ +
+
+ + +
+ +

+ {{ t('admin.accounts.quotaControl.rpmLimit.userMsgQueueHint') }} +

+
+
@@ -780,8 +815,12 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue' import Select from '@/components/common/Select.vue' import ProxySelector from '@/components/common/ProxySelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue' +import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' import Icon from '@/components/icons/Icon.vue' -import { buildModelMappingObject as buildModelMappingPayload } from '@/composables/useModelWhitelist' +import { + buildModelMappingObject as buildModelMappingPayload, + getPresetMappingsByPlatform +} from '@/composables/useModelWhitelist' interface Props { show: boolean @@ -813,26 +852,20 @@ const allAnthropicOAuthOrSetupToken = computed(() => { ) }) -const platformModelPrefix: Record = { - anthropic: ['claude-'], - antigravity: ['claude-', 'gemini-', 'gpt-oss-', 'tab_'], - openai: ['gpt-'], - gemini: ['gemini-'], - sora: [] -} - -const filteredModels = computed(() => { - if (props.selectedPlatforms.length === 0) return allModels - const prefixes = [...new Set(props.selectedPlatforms.flatMap(p => platformModelPrefix[p] || []))] - if (prefixes.length === 0) return allModels - return allModels.filter(m => prefixes.some(prefix => m.value.startsWith(prefix))) -}) - const filteredPresets = computed(() => { - if (props.selectedPlatforms.length === 0) return presetMappings - const prefixes = [...new Set(props.selectedPlatforms.flatMap(p => platformModelPrefix[p] || []))] - if (prefixes.length === 0) return presetMappings - return presetMappings.filter(m => prefixes.some(prefix => m.from.startsWith(prefix))) + if (props.selectedPlatforms.length === 0) return [] + + const dedupedPresets = new Map[number]>() + for (const platform of props.selectedPlatforms) { + for (const preset of getPresetMappingsByPlatform(platform)) { + const key = `${preset.from}=>${preset.to}` + if (!dedupedPresets.has(key)) { + dedupedPresets.set(key, preset) + } + } + } + + return 
Array.from(dedupedPresets.values()) }) // Model mapping type @@ -848,6 +881,7 @@ const enableCustomErrorCodes = ref(false) const enableInterceptWarmup = ref(false) const enableProxy = ref(false) const enableConcurrency = ref(false) +const enableLoadFactor = ref(false) const enablePriority = ref(false) const enableRateMultiplier = ref(false) const enableStatus = ref(false) @@ -868,6 +902,7 @@ const customErrorCodeInput = ref(null) const interceptWarmupRequests = ref(false) const proxyId = ref(null) const concurrency = ref(1) +const loadFactor = ref(null) const priority = ref(1) const rateMultiplier = ref(1) const status = ref<'active' | 'inactive'>('active') @@ -876,190 +911,12 @@ const rpmLimitEnabled = ref(false) const bulkBaseRpm = ref(null) const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') const bulkRpmStickyBuffer = ref(null) - -// All models list (combined Anthropic + OpenAI + Gemini) -const allModels = [ - { value: 'claude-opus-4-6', label: 'Claude Opus 4.6' }, - { value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' }, - { value: 'claude-opus-4-5-20251101', label: 'Claude Opus 4.5' }, - { value: 'claude-sonnet-4-20250514', label: 'Claude Sonnet 4' }, - { value: 'claude-sonnet-4-5-20250929', label: 'Claude Sonnet 4.5' }, - { value: 'claude-3-5-haiku-20241022', label: 'Claude 3.5 Haiku' }, - { value: 'claude-haiku-4-5-20251001', label: 'Claude Haiku 4.5' }, - { value: 'claude-3-opus-20240229', label: 'Claude 3 Opus' }, - { value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' }, - { value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' }, - { value: 'gpt-5.3-codex', label: 'GPT-5.3 Codex' }, - { value: 'gpt-5.3-codex-spark', label: 'GPT-5.3 Codex Spark' }, - { value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' }, - { value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' }, - { value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' }, - { value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' }, - { value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' }, - 
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' }, - { value: 'gpt-5-2025-08-07', label: 'GPT-5' }, - { value: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' }, - { value: 'gemini-2.5-flash', label: 'Gemini 2.5 Flash' }, - { value: 'gemini-2.5-pro', label: 'Gemini 2.5 Pro' }, - { value: 'gemini-3.1-flash-image', label: 'Gemini 3.1 Flash Image' }, - { value: 'gemini-3-pro-image', label: 'Gemini 3 Pro Image (Legacy)' }, - { value: 'gemini-3-flash-preview', label: 'Gemini 3 Flash Preview' }, - { value: 'gemini-3-pro-preview', label: 'Gemini 3 Pro Preview' } -] - -// Preset mappings (combined Anthropic + OpenAI + Gemini) -const presetMappings = [ - { - label: 'Sonnet 4', - from: 'claude-sonnet-4-20250514', - to: 'claude-sonnet-4-20250514', - color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' - }, - { - label: 'Sonnet 4.5', - from: 'claude-sonnet-4-5-20250929', - to: 'claude-sonnet-4-5-20250929', - color: - 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' - }, - { - label: 'Opus 4.5', - from: 'claude-opus-4-5-20251101', - to: 'claude-opus-4-5-20251101', - color: - 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' - }, - { - label: 'Opus 4.6', - from: 'claude-opus-4-6', - to: 'claude-opus-4-6-thinking', - color: - 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' - }, - { - label: 'Opus 4.6-thinking', - from: 'claude-opus-4-6-thinking', - to: 'claude-opus-4-6-thinking', - color: - 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' - }, - { - label: 'Sonnet 4.6', - from: 'claude-sonnet-4-6', - to: 'claude-sonnet-4-6', - color: - 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' - }, - { - label: 'Sonnet4→4.6', - from: 'claude-sonnet-4-20250514', - to: 'claude-sonnet-4-6', - color: 'bg-sky-100 
text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' - }, - { - label: 'Sonnet4.5→4.6', - from: 'claude-sonnet-4-5-20250929', - to: 'claude-sonnet-4-6', - color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' - }, - { - label: 'Sonnet3.5→4.6', - from: 'claude-3-5-sonnet-20241022', - to: 'claude-sonnet-4-6', - color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' - }, - { - label: 'Opus4.5→4.6', - from: 'claude-opus-4-5-20251101', - to: 'claude-opus-4-6-thinking', - color: - 'bg-violet-100 text-violet-700 hover:bg-violet-200 dark:bg-violet-900/30 dark:text-violet-400' - }, - { - label: 'Opus->Sonnet', - from: 'claude-opus-4-5-20251101', - to: 'claude-sonnet-4-5-20250929', - color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' - }, - { - label: 'Gemini 3.1 Image', - from: 'gemini-3.1-flash-image', - to: 'gemini-3.1-flash-image', - color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' - }, - { - label: 'G3 Image→3.1', - from: 'gemini-3-pro-image', - to: 'gemini-3.1-flash-image', - color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' - }, - { - label: 'GPT-5.3 Codex', - from: 'gpt-5.3-codex', - to: 'gpt-5.3-codex', - color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' - }, - { - label: 'GPT-5.3 Spark', - from: 'gpt-5.3-codex-spark', - to: 'gpt-5.3-codex-spark', - color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' - }, - { - label: '5.2→5.3', - from: 'gpt-5.2-codex', - to: 'gpt-5.3-codex', - color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400' - }, - { - label: 'GPT-5.2', - from: 'gpt-5.2-2025-12-11', - to: 'gpt-5.2-2025-12-11', - color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 
dark:text-green-400' - }, - { - label: 'GPT-5.2 Codex', - from: 'gpt-5.2-codex', - to: 'gpt-5.2-codex', - color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' - }, - { - label: 'Max->Codex', - from: 'gpt-5.1-codex-max', - to: 'gpt-5.1-codex', - color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' - }, - { - label: '3-Pro-Preview→3.1-Pro-High', - from: 'gemini-3-pro-preview', - to: 'gemini-3.1-pro-high', - color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' - }, - { - label: '3-Pro-High→3.1-Pro-High', - from: 'gemini-3-pro-high', - to: 'gemini-3.1-pro-high', - color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' - }, - { - label: '3-Pro-Low→3.1-Pro-Low', - from: 'gemini-3-pro-low', - to: 'gemini-3.1-pro-low', - color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' - }, - { - label: '3-Flash透传', - from: 'gemini-3-flash', - to: 'gemini-3-flash', - color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400' - }, - { - label: '2.5-Flash-Lite透传', - from: 'gemini-2.5-flash-lite', - to: 'gemini-2.5-flash-lite', - color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' - } -] +const userMsgQueueMode = ref(null) +const umqModeOptions = computed(() => [ + { value: '', label: t('admin.accounts.quotaControl.rpmLimit.umqModeOff') }, + { value: 'throttle', label: t('admin.accounts.quotaControl.rpmLimit.umqModeThrottle') }, + { value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') }, +]) // Common HTTP error codes const commonErrorCodes = [ @@ -1168,6 +1025,12 @@ const buildUpdatePayload = (): Record | null => { updates.concurrency = concurrency.value } + if (enableLoadFactor.value) { + // 空值/NaN/0 时发送 0(后端约定 <= 0 表示清除) + const lf = loadFactor.value + 
updates.load_factor = (lf != null && !Number.isNaN(lf) && lf > 0) ? lf : 0 + } + if (enablePriority.value) { updates.priority = priority.value } @@ -1249,6 +1112,14 @@ const buildUpdatePayload = (): Record | null => { updates.extra = extra } + // UMQ mode(独立于 RPM 保存) + if (userMsgQueueMode.value !== null) { + if (!updates.extra) updates.extra = {} + const umqExtra = updates.extra as Record + umqExtra.user_msg_queue_mode = userMsgQueueMode.value // '' = 清除账号级覆盖 + umqExtra.user_msg_queue_enabled = false // 清理旧字段(JSONB merge) + } + return Object.keys(updates).length > 0 ? updates : null } @@ -1305,11 +1176,13 @@ const handleSubmit = async () => { enableInterceptWarmup.value || enableProxy.value || enableConcurrency.value || + enableLoadFactor.value || enablePriority.value || enableRateMultiplier.value || enableStatus.value || enableGroups.value || - enableRpmLimit.value + enableRpmLimit.value || + userMsgQueueMode.value !== null if (!hasAnyFieldEnabled) { appStore.showError(t('admin.accounts.bulkEdit.noFieldsSelected')) @@ -1394,6 +1267,7 @@ watch( enableInterceptWarmup.value = false enableProxy.value = false enableConcurrency.value = false + enableLoadFactor.value = false enablePriority.value = false enableRateMultiplier.value = false enableStatus.value = false @@ -1410,10 +1284,16 @@ watch( interceptWarmupRequests.value = false proxyId.value = null concurrency.value = 1 + loadFactor.value = null priority.value = 1 rateMultiplier.value = 1 status.value = 'active' groupIds.value = [] + rpmLimitEnabled.value = false + bulkBaseRpm.value = null + bulkRpmStrategy.value = 'tiered' + bulkRpmStickyBuffer.value = null + userMsgQueueMode.value = null // Reset mixed channel warning state showMixedChannelWarning.value = false diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 97a6fbce..6f02a9d9 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ 
b/frontend/src/components/account/CreateAccountModal.vue @@ -232,7 +232,7 @@
-
+
+ + +
@@ -1127,6 +1158,58 @@
+ +
+
+
+ +

+ {{ t('admin.accounts.poolModeHint') }} +

+
+ +
+
+

+ + {{ t('admin.accounts.poolModeInfo') }} +

+
+
+ + +

+ {{ + t('admin.accounts.poolModeRetryCountHint', { + default: DEFAULT_POOL_MODE_RETRY_COUNT, + max: MAX_POOL_MODE_RETRY_COUNT + }) + }} +

+
+
+
@@ -1227,6 +1310,418 @@
+ +
+ +
+ +
+ + +
+
+ + + + + +
+ + +
+ + +
+ + +

{{ t('admin.accounts.bedrockRegionHint') }}

+
+ + +
+ +

{{ t('admin.accounts.bedrockForceGlobalHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+ + + + +
+ + +
+ +
+
+
+ + +
+
+
+ +

+ {{ t('admin.accounts.poolModeHint') }} +

+
+ +
+
+

+ + {{ t('admin.accounts.poolModeInfo') }} +

+
+
+ + +

+ {{ + t('admin.accounts.poolModeRetryCountHint', { + default: DEFAULT_POOL_MODE_RETRY_COUNT, + max: MAX_POOL_MODE_RETRY_COUNT + }) + }} +

+
+
+
+ + +
+
+

{{ t('admin.accounts.quotaLimit') }}

+

+ {{ t('admin.accounts.quotaLimitHint') }} +

+
+ +
+ + +
+ + +
+

+ {{ t('admin.accounts.openai.modelRestrictionDisabledByPassthrough') }} +

+
+ + +
+
@@ -1625,6 +2120,27 @@ />

{{ t('admin.accounts.quotaControl.rpmLimit.stickyBufferHint') }}

+ +
+ + +
+ +

+ {{ t('admin.accounts.quotaControl.rpmLimit.userMsgQueueHint') }} +

+
+ +
@@ -1728,10 +2244,18 @@ -
+
- + +
+
+ + +

{{ t('admin.accounts.loadFactorHint') }}

@@ -1786,7 +2310,7 @@
- +

- {{ t('admin.accounts.openai.wsModeConcurrencyHint') }} + {{ t(openAIWSModeConcurrencyHintKey) }}

@@ -1925,6 +2449,33 @@
+
+ +
+ + ? + +
+ {{ t('admin.accounts.allowOveragesTooltip') }} +
+
+
+
('oauth-based') // UI selection for account category +const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock'>('oauth-based') // UI selection for account category const addMethod = ref('oauth') // For oauth-based: 'oauth' or 'setup-token' const apiKeyBaseUrl = ref('https://api.anthropic.com') const apiKeyValue = ref('') +const editQuotaLimit = ref(null) +const editQuotaDailyLimit = ref(null) +const editQuotaWeeklyLimit = ref(null) +const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editDailyResetHour = ref(null) +const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editWeeklyResetDay = ref(null) +const editWeeklyResetHour = ref(null) +const editResetTimezone = ref(null) const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) +const DEFAULT_POOL_MODE_RETRY_COUNT = 3 +const MAX_POOL_MODE_RETRY_COUNT = 10 +const poolModeEnabled = ref(false) +const poolModeRetryCount = ref(DEFAULT_POOL_MODE_RETRY_COUNT) const customErrorCodesEnabled = ref(false) const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) @@ -2452,6 +3018,7 @@ const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OF const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling +const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream const soraAccountType = ref<'oauth' | 'apikey'>('oauth') // For sora: oauth or apikey (upstream) const upstreamBaseUrl = ref('') // For upstream type: base URL @@ -2460,6 +3027,16 @@ const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist' const antigravityWhitelistModels = ref([]) const antigravityModelMappings = ref([]) const antigravityPresetMappings = 
computed(() => getPresetMappingsByPlatform('antigravity')) +const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock')) + +// Bedrock credentials +const bedrockAuthMode = ref<'sigv4' | 'apikey'>('sigv4') +const bedrockAccessKeyId = ref('') +const bedrockSecretAccessKey = ref('') +const bedrockSessionToken = ref('') +const bedrockRegion = ref('us-east-1') +const bedrockForceGlobal = ref(false) +const bedrockApiKeyValue = ref('') const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) const getModelMappingKey = createStableObjectKeyResolver('create-model-mapping') @@ -2468,6 +3045,13 @@ const getTempUnschedRuleKey = createStableObjectKeyResolver const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one') const geminiAIStudioOAuthEnabled = ref(false) +function buildAntigravityExtra(): Record | undefined { + const extra: Record = {} + if (mixedScheduling.value) extra.mixed_scheduling = true + if (allowOverages.value) extra.allow_overages = true + return Object.keys(extra).length > 0 ? 
extra : undefined +} + const showMixedChannelWarning = ref(false) const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>( null @@ -2489,6 +3073,12 @@ const rpmLimitEnabled = ref(false) const baseRpm = ref(null) const rpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') const rpmStickyBuffer = ref(null) +const userMsgQueueMode = ref('') +const umqModeOptions = computed(() => [ + { value: '', label: t('admin.accounts.quotaControl.rpmLimit.umqModeOff') }, + { value: 'throttle', label: t('admin.accounts.quotaControl.rpmLimit.umqModeThrottle') }, + { value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') }, +]) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false) @@ -2514,8 +3104,9 @@ const geminiSelectedTier = computed(() => { const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, - { value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') }, - { value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') } + // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 + // { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) const openaiResponsesWebSocketV2Mode = computed({ @@ -2534,6 +3125,10 @@ const openaiResponsesWebSocketV2Mode = computed({ } }) +const openAIWSModeConcurrencyHintKey = computed(() => + resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value) +) + const isOpenAIModelRestrictionDisabled = computed(() => form.platform === 'openai' && openaiPassthroughEnabled.value ) @@ -2600,6 +3195,7 @@ const form = reactive({ credentials: {} as Record, proxy_id: null as number | null, concurrency: 10, + load_factor: null as number | null, priority: 1, rate_multiplier: 1, 
group_ids: [] as number[], @@ -2612,6 +3208,10 @@ const isOAuthFlow = computed(() => { if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { return false } + // Bedrock 类型不需要 OAuth 流程 + if (form.platform === 'anthropic' && accountCategory.value === 'bedrock') { + return false + } return accountCategory.value === 'oauth-based' }) @@ -2679,6 +3279,11 @@ watch( form.type = 'apikey' return } + // Bedrock 类型 + if (form.platform === 'anthropic' && category === 'bedrock') { + form.type = 'bedrock' as AccountType + return + } if (category === 'oauth-based') { form.type = method as AccountType // 'oauth' or 'setup-token' } else { @@ -2712,10 +3317,19 @@ watch( accountCategory.value = 'oauth-based' antigravityAccountType.value = 'oauth' } else { + allowOverages.value = false antigravityWhitelistModels.value = [] antigravityModelMappings.value = [] antigravityModelRestrictionMode.value = 'mapping' } + // Reset Bedrock fields when switching platforms + bedrockAccessKeyId.value = '' + bedrockSecretAccessKey.value = '' + bedrockSessionToken.value = '' + bedrockRegion.value = 'us-east-1' + bedrockForceGlobal.value = false + bedrockAuthMode.value = 'sigv4' + bedrockApiKeyValue.value = '' // Reset Anthropic/Antigravity-specific settings when switching to other platforms if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') { interceptWarmupRequests.value = false @@ -3079,6 +3693,7 @@ const resetForm = () => { form.credentials = {} form.proxy_id = null form.concurrency = 10 + form.load_factor = null form.priority = 1 form.rate_multiplier = 1 form.group_ids = [] @@ -3087,6 +3702,15 @@ const resetForm = () => { addMethod.value = 'oauth' apiKeyBaseUrl.value = 'https://api.anthropic.com' apiKeyValue.value = '' + editQuotaLimit.value = null + editQuotaDailyLimit.value = null + editQuotaWeeklyLimit.value = null + editDailyResetMode.value = null + editDailyResetHour.value = null + editWeeklyResetMode.value = null + editWeeklyResetDay.value = null 
+ editWeeklyResetHour.value = null + editResetTimezone.value = null modelMappings.value = [] modelRestrictionMode.value = 'whitelist' allowedModels.value = [...claudeModels] // Default fill related models @@ -3096,6 +3720,8 @@ const resetForm = () => { fetchAntigravityDefaultMappings().then(mappings => { antigravityModelMappings.value = [...mappings] }) + poolModeEnabled.value = false + poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT customErrorCodesEnabled.value = false selectedErrorCodes.value = [] customErrorCodeInput.value = null @@ -3117,10 +3743,12 @@ const resetForm = () => { baseRpm.value = null rpmStrategy.value = 'tiered' rpmStickyBuffer.value = null + userMsgQueueMode.value = '' tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false cacheTTLOverrideEnabled.value = false cacheTTLOverrideTarget.value = '5m' + allowOverages.value = false antigravityAccountType.value = 'oauth' upstreamBaseUrl.value = '' upstreamApiKey.value = '' @@ -3152,10 +3780,13 @@ const buildOpenAIExtra = (base?: Record): Record = { ...(base || {}) } - extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value - extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value - extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) - extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + if (accountCategory.value === 'oauth-based') { + extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value + extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) + } else if (accountCategory.value === 'apikey') { + extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value + extra.openai_apikey_responses_websockets_v2_enabled = 
isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + } // 清理兼容旧键,统一改用分类型开关。 delete extra.responses_websockets_v2_enabled delete extra.openai_ws_enabled @@ -3244,6 +3875,20 @@ const handleMixedChannelCancel = () => { clearMixedChannelDialog() } +const normalizePoolModeRetryCount = (value: number) => { + if (!Number.isFinite(value)) { + return DEFAULT_POOL_MODE_RETRY_COUNT + } + const normalized = Math.trunc(value) + if (normalized < 0) { + return 0 + } + if (normalized > MAX_POOL_MODE_RETRY_COUNT) { + return MAX_POOL_MODE_RETRY_COUNT + } + return normalized +} + const handleSubmit = async () => { // For OAuth-based type, handle OAuth flow (goes to step 2) if (isOAuthFlow.value) { @@ -3261,6 +3906,64 @@ const handleSubmit = async () => { return } + // For Bedrock type, create directly + if (form.platform === 'anthropic' && accountCategory.value === 'bedrock') { + if (!form.name.trim()) { + appStore.showError(t('admin.accounts.pleaseEnterAccountName')) + return + } + + const credentials: Record = { + auth_mode: bedrockAuthMode.value, + aws_region: bedrockRegion.value.trim() || 'us-east-1', + } + + if (bedrockAuthMode.value === 'sigv4') { + if (!bedrockAccessKeyId.value.trim()) { + appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired')) + return + } + if (!bedrockSecretAccessKey.value.trim()) { + appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired')) + return + } + credentials.aws_access_key_id = bedrockAccessKeyId.value.trim() + credentials.aws_secret_access_key = bedrockSecretAccessKey.value.trim() + if (bedrockSessionToken.value.trim()) { + credentials.aws_session_token = bedrockSessionToken.value.trim() + } + } else { + if (!bedrockApiKeyValue.value.trim()) { + appStore.showError(t('admin.accounts.bedrockApiKeyRequired')) + return + } + credentials.api_key = bedrockApiKeyValue.value.trim() + } + + if (bedrockForceGlobal.value) { + credentials.aws_force_global = 'true' + } + + // Model mapping + const modelMapping = 
buildModelMappingObject( + modelRestrictionMode.value, allowedModels.value, modelMappings.value + ) + if (modelMapping) { + credentials.model_mapping = modelMapping + } + + // Pool mode + if (poolModeEnabled.value) { + credentials.pool_mode = true + credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + } + + applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') + + await createAccountAndFinish('anthropic', 'bedrock' as AccountType, credentials) + return + } + // For Antigravity upstream type, create directly if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { if (!form.name.trim()) { @@ -3294,7 +3997,7 @@ const handleSubmit = async () => { applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') - const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined + const extra = buildAntigravityExtra() await createAccountAndFinish(form.platform, 'apikey', credentials, extra) return } @@ -3343,6 +4046,12 @@ const handleSubmit = async () => { } } + // Add pool mode if enabled + if (poolModeEnabled.value) { + credentials.pool_mode = true + credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + } + // Add custom error codes if enabled if (customErrorCodesEnabled.value) { credentials.custom_error_codes_enabled = true @@ -3446,6 +4155,7 @@ const handleImportAccessToken = async (accessTokenInput: string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? 
undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3496,15 +4206,46 @@ const createAccountAndFinish = async ( if (!applyTempUnschedConfig(credentials)) { return } + // Inject quota limits for apikey/bedrock accounts + let finalExtra = extra + if (type === 'apikey' || type === 'bedrock') { + const quotaExtra: Record = { ...(extra || {}) } + if (editQuotaLimit.value != null && editQuotaLimit.value > 0) { + quotaExtra.quota_limit = editQuotaLimit.value + } + if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) { + quotaExtra.quota_daily_limit = editQuotaDailyLimit.value + } + if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) { + quotaExtra.quota_weekly_limit = editQuotaWeeklyLimit.value + } + // Quota reset mode config + if (editDailyResetMode.value === 'fixed') { + quotaExtra.quota_daily_reset_mode = 'fixed' + quotaExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0 + } + if (editWeeklyResetMode.value === 'fixed') { + quotaExtra.quota_weekly_reset_mode = 'fixed' + quotaExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1 + quotaExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 0 + } + if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') { + quotaExtra.quota_reset_timezone = editResetTimezone.value || 'UTC' + } + if (Object.keys(quotaExtra).length > 0) { + finalExtra = quotaExtra + } + } await doCreateAccount({ name: form.name, notes: form.notes, platform, type, credentials, - extra, + extra: finalExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? 
undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3543,6 +4284,14 @@ const handleOpenAIExchange = async (authCode: string) => { const shouldCreateOpenAI = form.platform === 'openai' const shouldCreateSora = form.platform === 'sora' + // Add model mapping for OpenAI OAuth accounts(透传模式下不应用) + if (shouldCreateOpenAI && !isOpenAIModelRestrictionDisabled.value) { + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + credentials.model_mapping = modelMapping + } + } + // 应用临时不可调度配置 if (!applyTempUnschedConfig(credentials)) { return @@ -3560,6 +4309,7 @@ const handleOpenAIExchange = async (authCode: string) => { extra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3589,6 +4339,7 @@ const handleOpenAIExchange = async (authCode: string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3651,6 +4402,14 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined const extra = buildOpenAIExtra(oauthExtra) + // Add model mapping for OpenAI OAuth accounts(透传模式下不应用) + if (shouldCreateOpenAI && !isOpenAIModelRestrictionDisabled.value) { + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + credentials.model_mapping = modelMapping + } + } + // Generate account name with index for batch const accountName = refreshTokens.length > 1 ? 
`${form.name} #${i + 1}` : form.name @@ -3666,6 +4425,7 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { extra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3693,6 +4453,7 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3781,6 +4542,7 @@ const handleSoraValidateST = async (sessionTokenInput: string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3869,6 +4631,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => { extra: {}, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3980,7 +4743,7 @@ const handleAntigravityExchange = async (authCode: string) => { if (antigravityModelMapping) { credentials.model_mapping = antigravityModelMapping } - const extra = mixedScheduling.value ? 
{ mixed_scheduling: true } : undefined + const extra = buildAntigravityExtra() await createAccountAndFinish('antigravity', 'oauth', credentials, extra) } catch (error: any) { antigravityOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') @@ -4027,14 +4790,22 @@ const handleAnthropicExchange = async (authCode: string) => { } // Add RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - extra.base_rpm = baseRpm.value + if (rpmLimitEnabled.value) { + const DEFAULT_BASE_RPM = 15 + extra.base_rpm = (baseRpm.value != null && baseRpm.value > 0) + ? baseRpm.value + : DEFAULT_BASE_RPM extra.rpm_strategy = rpmStrategy.value if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { extra.rpm_sticky_buffer = rpmStickyBuffer.value } } + // UMQ mode(独立于 RPM) + if (userMsgQueueMode.value) { + extra.user_msg_queue_mode = userMsgQueueMode.value + } + // Add TLS fingerprint settings if (tlsFingerprintEnabled.value) { extra.enable_tls_fingerprint = true @@ -4134,14 +4905,22 @@ const handleCookieAuth = async (sessionKey: string) => { } // Add RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - extra.base_rpm = baseRpm.value + if (rpmLimitEnabled.value) { + const DEFAULT_BASE_RPM = 15 + extra.base_rpm = (baseRpm.value != null && baseRpm.value > 0) + ? baseRpm.value + : DEFAULT_BASE_RPM extra.rpm_strategy = rpmStrategy.value if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { extra.rpm_sticky_buffer = rpmStickyBuffer.value } } + // UMQ mode(独立于 RPM) + if (userMsgQueueMode.value) { + extra.user_msg_queue_mode = userMsgQueueMode.value + } + // Add TLS fingerprint settings if (tlsFingerprintEnabled.value) { extra.enable_tls_fingerprint = true @@ -4176,6 +4955,7 @@ const handleCookieAuth = async (sessionKey: string) => { extra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? 
undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 184eff98..c2f2f7d2 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -251,6 +251,58 @@ + +
+
+
+ +

+ {{ t('admin.accounts.poolModeHint') }} +

+
+ +
+
+

+ + {{ t('admin.accounts.poolModeInfo') }} +

+
+
+ + +

+ {{ + t('admin.accounts.poolModeRetryCountHint', { + default: DEFAULT_POOL_MODE_RETRY_COUNT, + max: MAX_POOL_MODE_RETRY_COUNT + }) + }} +

+
+
+
@@ -351,6 +403,142 @@
+ +
+ + +
+

+ {{ t('admin.accounts.openai.modelRestrictionDisabledByPassthrough') }} +

+
+ + +
+
@@ -375,6 +563,200 @@
+ +
+ + + + +
+ + +

{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}

+
+ + +
+ + +

{{ t('admin.accounts.bedrockRegionHint') }}

+
+ + +
+ +

{{ t('admin.accounts.bedrockForceGlobalHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+ + + + +
+ + +
+ +
+
+
+ + +
+
+
+ +

+ {{ t('admin.accounts.poolModeHint') }} +

+
+ +
+
+

+ + {{ t('admin.accounts.poolModeInfo') }} +

+
+
+ + +

+ {{ + t('admin.accounts.poolModeRetryCountHint', { + default: DEFAULT_POOL_MODE_RETRY_COUNT, + max: MAX_POOL_MODE_RETRY_COUNT + }) + }} +

+
+
+
+
@@ -650,10 +1032,18 @@
-
+
- + +
+
+ + +

{{ t('admin.accounts.loadFactorHint') }}

@@ -708,7 +1098,7 @@
- +

- {{ t('admin.accounts.openai.wsModeConcurrencyHint') }} + {{ t(openAIWSModeConcurrencyHintKey) }}

@@ -759,6 +1149,36 @@
+ +
+
+

{{ t('admin.accounts.quotaLimit') }}

+

+ {{ t('admin.accounts.quotaLimitHint') }} +

+
+ +
+

{{ t('admin.accounts.quotaControl.rpmLimit.stickyBufferHint') }}

+ +
+ + +
+ +

+ {{ t('admin.accounts.quotaControl.rpmLimit.userMsgQueueHint') }} +

+
+ +
@@ -1169,6 +1610,33 @@ +
+ +
+ + ? + +
+ {{ t('admin.accounts.allowOveragesTooltip') }} +
+
+
+
@@ -1248,14 +1716,16 @@ import Icon from '@/components/icons/Icon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' +import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue' import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import { - OPENAI_WS_MODE_DEDICATED, + // OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, - OPENAI_WS_MODE_SHARED, + OPENAI_WS_MODE_PASSTHROUGH, isOpenAIWSModeEnabled, + resolveOpenAIWSModeConcurrencyHintKey, type OpenAIWSMode, resolveOpenAIWSModeFromExtra } from '@/utils/openaiWsMode' @@ -1292,6 +1762,7 @@ const baseUrlHint = computed(() => { }) const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity')) +const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock')) // Model mapping type interface ModelMapping { @@ -1310,15 +1781,31 @@ interface TempUnschedRuleForm { const submitting = ref(false) const editBaseUrl = ref('https://api.anthropic.com') const editApiKey = ref('') +// Bedrock credentials +const editBedrockAccessKeyId = ref('') +const editBedrockSecretAccessKey = ref('') +const editBedrockSessionToken = ref('') +const editBedrockRegion = ref('') +const editBedrockForceGlobal = ref(false) +const editBedrockApiKeyValue = ref('') +const isBedrockAPIKeyMode = computed(() => + props.account?.type === 'bedrock' && + (props.account?.credentials as Record)?.auth_mode === 'apikey' +) const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) +const DEFAULT_POOL_MODE_RETRY_COUNT = 3 +const MAX_POOL_MODE_RETRY_COUNT = 10 +const poolModeEnabled = ref(false) +const 
poolModeRetryCount = ref(DEFAULT_POOL_MODE_RETRY_COUNT) const customErrorCodesEnabled = ref(false) const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) const interceptWarmupRequests = ref(false) const autoPauseOnExpired = ref(false) const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling +const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const antigravityWhitelistModels = ref([]) const antigravityModelMappings = ref([]) @@ -1347,6 +1834,12 @@ const rpmLimitEnabled = ref(false) const baseRpm = ref(null) const rpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') const rpmStickyBuffer = ref(null) +const userMsgQueueMode = ref('') +const umqModeOptions = computed(() => [ + { value: '', label: t('admin.accounts.quotaControl.rpmLimit.umqModeOff') }, + { value: 'throttle', label: t('admin.accounts.quotaControl.rpmLimit.umqModeThrottle') }, + { value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') }, +]) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false) @@ -1358,10 +1851,20 @@ const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) +const editQuotaLimit = ref(null) +const editQuotaDailyLimit = ref(null) +const editQuotaWeeklyLimit = ref(null) +const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editDailyResetHour = ref(null) +const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editWeeklyResetDay = ref(null) +const editWeeklyResetHour = ref(null) +const editResetTimezone = ref(null) const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: 
t('admin.accounts.openai.wsModeOff') }, - { value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') }, - { value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') } + // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 + // { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) const openaiResponsesWebSocketV2Mode = computed({ get: () => { @@ -1378,6 +1881,9 @@ const openaiResponsesWebSocketV2Mode = computed({ openaiOAuthResponsesWebSocketV2Mode.value = mode } }) +const openAIWSModeConcurrencyHintKey = computed(() => + resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value) +) const isOpenAIModelRestrictionDisabled = computed(() => props.account?.platform === 'openai' && openaiPassthroughEnabled.value ) @@ -1433,17 +1939,24 @@ const form = reactive({ notes: '', proxy_id: null as number | null, concurrency: 1, + load_factor: null as number | null, priority: 1, rate_multiplier: 1, - status: 'active' as 'active' | 'inactive', + status: 'active' as 'active' | 'inactive' | 'error', group_ids: [] as number[], expires_at: null as number | null }) -const statusOptions = computed(() => [ - { value: 'active', label: t('common.active') }, - { value: 'inactive', label: t('common.inactive') } -]) +const statusOptions = computed(() => { + const options = [ + { value: 'active', label: t('common.active') }, + { value: 'inactive', label: t('common.inactive') } + ] + if (form.status === 'error') { + options.push({ value: 'error', label: t('admin.accounts.status.error') }) + } + return options +}) const expiresAtInput = computed({ get: () => formatDateTimeLocal(form.expires_at), @@ -1453,6 +1966,20 @@ const expiresAtInput = computed({ }) // Watchers +const normalizePoolModeRetryCount = (value: number) => { + if (!Number.isFinite(value)) { + return DEFAULT_POOL_MODE_RETRY_COUNT + } + const normalized = 
Math.trunc(value) + if (normalized < 0) { + return 0 + } + if (normalized > MAX_POOL_MODE_RETRY_COUNT) { + return MAX_POOL_MODE_RETRY_COUNT + } + return normalized +} + watch( () => props.account, (newAccount) => { @@ -1466,9 +1993,12 @@ watch( form.notes = newAccount.notes || '' form.proxy_id = newAccount.proxy_id form.concurrency = newAccount.concurrency + form.load_factor = newAccount.load_factor ?? null form.priority = newAccount.priority form.rate_multiplier = newAccount.rate_multiplier ?? 1 - form.status = newAccount.status as 'active' | 'inactive' + form.status = (newAccount.status === 'active' || newAccount.status === 'inactive' || newAccount.status === 'error') + ? newAccount.status + : 'active' form.group_ids = newAccount.group_ids || [] form.expires_at = newAccount.expires_at ?? null @@ -1478,8 +2008,11 @@ watch( autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true // Load mixed scheduling setting (only for antigravity accounts) + mixedScheduling.value = false + allowOverages.value = false const extra = newAccount.extra as Record | undefined mixedScheduling.value = extra?.mixed_scheduling === true + allowOverages.value = extra?.allow_overages === true // Load OpenAI passthrough toggle (OpenAI OAuth/API Key) openaiPassthroughEnabled.value = false @@ -1509,6 +2042,33 @@ watch( anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true } + // Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above) + if (newAccount.type === 'apikey' || newAccount.type === 'bedrock') { + const quotaVal = extra?.quota_limit as number | undefined + editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null + const dailyVal = extra?.quota_daily_limit as number | undefined + editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null + const weeklyVal = extra?.quota_weekly_limit as number | undefined + editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? 
weeklyVal : null + // Load quota reset mode config + editDailyResetMode.value = (extra?.quota_daily_reset_mode as 'rolling' | 'fixed') || null + editDailyResetHour.value = (extra?.quota_daily_reset_hour as number) ?? null + editWeeklyResetMode.value = (extra?.quota_weekly_reset_mode as 'rolling' | 'fixed') || null + editWeeklyResetDay.value = (extra?.quota_weekly_reset_day as number) ?? null + editWeeklyResetHour.value = (extra?.quota_weekly_reset_hour as number) ?? null + editResetTimezone.value = (extra?.quota_reset_timezone as string) || null + } else { + editQuotaLimit.value = null + editQuotaDailyLimit.value = null + editQuotaWeeklyLimit.value = null + editDailyResetMode.value = null + editDailyResetHour.value = null + editWeeklyResetMode.value = null + editWeeklyResetDay.value = null + editWeeklyResetHour.value = null + editResetTimezone.value = null + } + // Load antigravity model mapping (Antigravity 只支持映射模式) if (newAccount.platform === 'antigravity') { const credentials = newAccount.credentials as Record | undefined @@ -1583,6 +2143,12 @@ watch( allowedModels.value = [] } + // Load pool mode + poolModeEnabled.value = credentials.pool_mode === true + poolModeRetryCount.value = normalizePoolModeRetryCount( + Number(credentials.pool_mode_retry_count ?? 
DEFAULT_POOL_MODE_RETRY_COUNT) + ) + // Load custom error codes customErrorCodesEnabled.value = credentials.custom_error_codes_enabled === true const existingErrorCodes = credentials.custom_error_codes as number[] | undefined @@ -1591,6 +2157,50 @@ watch( } else { selectedErrorCodes.value = [] } + } else if (newAccount.type === 'bedrock' && newAccount.credentials) { + const bedrockCreds = newAccount.credentials as Record + const authMode = (bedrockCreds.auth_mode as string) || 'sigv4' + editBedrockRegion.value = (bedrockCreds.aws_region as string) || '' + editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true' + + if (authMode === 'apikey') { + editBedrockApiKeyValue.value = '' + } else { + editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || '' + editBedrockSecretAccessKey.value = '' + editBedrockSessionToken.value = '' + } + + // Load pool mode for bedrock + poolModeEnabled.value = bedrockCreds.pool_mode === true + const retryCount = bedrockCreds.pool_mode_retry_count + poolModeRetryCount.value = (typeof retryCount === 'number' && retryCount >= 0) ? retryCount : DEFAULT_POOL_MODE_RETRY_COUNT + + // Load quota limits for bedrock + const bedrockExtra = (newAccount.extra as Record) || {} + editQuotaLimit.value = typeof bedrockExtra.quota_limit === 'number' ? bedrockExtra.quota_limit : null + editQuotaDailyLimit.value = typeof bedrockExtra.quota_daily_limit === 'number' ? bedrockExtra.quota_daily_limit : null + editQuotaWeeklyLimit.value = typeof bedrockExtra.quota_weekly_limit === 'number' ? 
bedrockExtra.quota_weekly_limit : null + + // Load model mappings for bedrock + const existingMappings = bedrockCreds.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } } else if (newAccount.type === 'upstream' && newAccount.credentials) { const credentials = newAccount.credentials as Record editBaseUrl.value = (credentials.base_url as string) || '' @@ -1602,9 +2212,35 @@ watch( ? 'https://generativelanguage.googleapis.com' : 'https://api.anthropic.com' editBaseUrl.value = platformDefaultUrl - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] + + // Load model mappings for OpenAI OAuth accounts + if (newAccount.platform === 'openai' && newAccount.credentials) { + const oauthCredentials = newAccount.credentials as Record + const existingMappings = oauthCredentials.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + 
modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } + poolModeEnabled.value = false + poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT customErrorCodesEnabled.value = false selectedErrorCodes.value = [] } @@ -1810,6 +2446,7 @@ function loadQuotaControlSettings(account: Account) { baseRpm.value = null rpmStrategy.value = 'tiered' rpmStickyBuffer.value = null + userMsgQueueMode.value = '' tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false cacheTTLOverrideEnabled.value = false @@ -1841,6 +2478,9 @@ function loadQuotaControlSettings(account: Account) { rpmStickyBuffer.value = account.rpm_sticky_buffer ?? null } + // UMQ mode(独立于 RPM 加载,防止编辑无 RPM 账号时丢失已有配置) + userMsgQueueMode.value = account.user_msg_queue_mode ?? '' + // Load TLS fingerprint setting if (account.enable_tls_fingerprint === true) { tlsFingerprintEnabled.value = true @@ -2004,6 +2644,11 @@ const handleSubmit = async () => { if (!props.account) return const accountID = props.account.id + if (form.status !== 'active' && form.status !== 'inactive' && form.status !== 'error') { + appStore.showError(t('admin.accounts.pleaseSelectStatus')) + return + } + const updatePayload: Record = { ...form } try { // 后端期望 proxy_id: 0 表示清除代理,而不是 null @@ -2013,6 +2658,11 @@ const handleSubmit = async () => { if (form.expires_at === null) { updatePayload.expires_at = 0 } + // load_factor: 空值/NaN/0/负数 时发送 0(后端约定 <= 0 = 清除) + const lf = form.load_factor + if (lf == null || Number.isNaN(lf) || lf <= 0) { + updatePayload.load_factor = 0 + } updatePayload.auto_pause_on_expired = autoPauseOnExpired.value // For apikey type, handle credentials update @@ -2023,6 +2673,7 @@ const handleSubmit = async () => { // Always update credentials for apikey type to handle model mapping changes const newCredentials: Record = { + ...currentCredentials, base_url: 
newBaseUrl } @@ -2043,15 +2694,29 @@ const handleSubmit = async () => { const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) if (modelMapping) { newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping } } else if (currentCredentials.model_mapping) { newCredentials.model_mapping = currentCredentials.model_mapping } + // Add pool mode if enabled + if (poolModeEnabled.value) { + newCredentials.pool_mode = true + newCredentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + } else { + delete newCredentials.pool_mode + delete newCredentials.pool_mode_retry_count + } + // Add custom error codes if enabled if (customErrorCodesEnabled.value) { newCredentials.custom_error_codes_enabled = true newCredentials.custom_error_codes = [...selectedErrorCodes.value] + } else { + delete newCredentials.custom_error_codes_enabled + delete newCredentials.custom_error_codes } // Add intercept warmup requests setting @@ -2078,6 +2743,56 @@ const handleSubmit = async () => { return } + updatePayload.credentials = newCredentials + } else if (props.account.type === 'bedrock') { + const currentCredentials = (props.account.credentials as Record) || {} + const newCredentials: Record = { ...currentCredentials } + + newCredentials.aws_region = editBedrockRegion.value.trim() + if (editBedrockForceGlobal.value) { + newCredentials.aws_force_global = 'true' + } else { + delete newCredentials.aws_force_global + } + + if (isBedrockAPIKeyMode.value) { + // API Key mode: only update api_key if user provided new value + if (editBedrockApiKeyValue.value.trim()) { + newCredentials.api_key = editBedrockApiKeyValue.value.trim() + } + } else { + // SigV4 mode + newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim() + if (editBedrockSecretAccessKey.value.trim()) { + newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim() + } + if 
(editBedrockSessionToken.value.trim()) { + newCredentials.aws_session_token = editBedrockSessionToken.value.trim() + } + } + + // Pool mode + if (poolModeEnabled.value) { + newCredentials.pool_mode = true + newCredentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + } else { + delete newCredentials.pool_mode + delete newCredentials.pool_mode_retry_count + } + + // Model mapping + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping + } + + applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') + if (!applyTempUnschedConfig(newCredentials)) { + return + } + updatePayload.credentials = newCredentials } else { // For oauth/setup-token types, only update intercept_warmup_requests if changed @@ -2092,6 +2807,28 @@ const handleSubmit = async () => { updatePayload.credentials = newCredentials } + // OpenAI OAuth: persist model mapping to credentials + if (props.account.platform === 'openai' && props.account.type === 'oauth') { + const currentCredentials = (updatePayload.credentials as Record) || + ((props.account.credentials as Record) || {}) + const newCredentials: Record = { ...currentCredentials } + const shouldApplyModelMapping = !openaiPassthroughEnabled.value + + if (shouldApplyModelMapping) { + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping + } + } else if (currentCredentials.model_mapping) { + // 透传模式保留现有映射 + newCredentials.model_mapping = currentCredentials.model_mapping + } + + updatePayload.credentials = newCredentials + } + // Antigravity: persist model mapping to credentials (applies to all antigravity types) // Antigravity 只支持映射模式 if 
(props.account.platform === 'antigravity') { @@ -2116,7 +2853,7 @@ const handleSubmit = async () => { updatePayload.credentials = newCredentials } - // For antigravity accounts, handle mixed_scheduling in extra + // For antigravity accounts, handle mixed_scheduling and allow_overages in extra if (props.account.platform === 'antigravity') { const currentExtra = (props.account.extra as Record) || {} const newExtra: Record = { ...currentExtra } @@ -2125,6 +2862,11 @@ const handleSubmit = async () => { } else { delete newExtra.mixed_scheduling } + if (allowOverages.value) { + newExtra.allow_overages = true + } else { + delete newExtra.allow_overages + } updatePayload.extra = newExtra } @@ -2152,8 +2894,11 @@ const handleSubmit = async () => { } // RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - newExtra.base_rpm = baseRpm.value + if (rpmLimitEnabled.value) { + const DEFAULT_BASE_RPM = 15 + newExtra.base_rpm = (baseRpm.value != null && baseRpm.value > 0) + ? 
baseRpm.value + : DEFAULT_BASE_RPM newExtra.rpm_strategy = rpmStrategy.value if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { newExtra.rpm_sticky_buffer = rpmStickyBuffer.value @@ -2166,6 +2911,14 @@ const handleSubmit = async () => { delete newExtra.rpm_sticky_buffer } + // UMQ mode(独立于 RPM 保存) + if (userMsgQueueMode.value) { + newExtra.user_msg_queue_mode = userMsgQueueMode.value + } else { + delete newExtra.user_msg_queue_mode + } + delete newExtra.user_msg_queue_enabled // 清理旧字段 + // TLS fingerprint setting if (tlsFingerprintEnabled.value) { newExtra.enable_tls_fingerprint = true @@ -2209,10 +2962,13 @@ const handleSubmit = async () => { const currentExtra = (props.account.extra as Record) || {} const newExtra: Record = { ...currentExtra } const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true - newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value - newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value - newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) - newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + if (props.account.type === 'oauth') { + newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value + newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) + } else if (props.account.type === 'apikey') { + newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value + newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + } delete newExtra.responses_websockets_v2_enabled delete newExtra.openai_ws_enabled if (openaiPassthroughEnabled.value) { @@ -2236,6 +2992,51 @@ const handleSubmit = async () => { 
updatePayload.extra = newExtra } + // For apikey/bedrock accounts, handle quota_limit in extra + if (props.account.type === 'apikey' || props.account.type === 'bedrock') { + const currentExtra = (updatePayload.extra as Record) || + (props.account.extra as Record) || {} + const newExtra: Record = { ...currentExtra } + if (editQuotaLimit.value != null && editQuotaLimit.value > 0) { + newExtra.quota_limit = editQuotaLimit.value + } else { + delete newExtra.quota_limit + } + if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) { + newExtra.quota_daily_limit = editQuotaDailyLimit.value + } else { + delete newExtra.quota_daily_limit + } + if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) { + newExtra.quota_weekly_limit = editQuotaWeeklyLimit.value + } else { + delete newExtra.quota_weekly_limit + } + // Quota reset mode config + if (editDailyResetMode.value === 'fixed') { + newExtra.quota_daily_reset_mode = 'fixed' + newExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0 + } else { + delete newExtra.quota_daily_reset_mode + delete newExtra.quota_daily_reset_hour + } + if (editWeeklyResetMode.value === 'fixed') { + newExtra.quota_weekly_reset_mode = 'fixed' + newExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1 + newExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 
0 + } else { + delete newExtra.quota_weekly_reset_mode + delete newExtra.quota_weekly_reset_day + delete newExtra.quota_weekly_reset_hour + } + if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') { + newExtra.quota_reset_timezone = editResetTimezone.value || 'UTC' + } else { + delete newExtra.quota_reset_timezone + } + updatePayload.extra = newExtra + } + const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => { await submitUpdateAccount(accountID, updatePayload) }) diff --git a/frontend/src/components/account/ModelWhitelistSelector.vue b/frontend/src/components/account/ModelWhitelistSelector.vue index 16ffa225..ebce3740 100644 --- a/frontend/src/components/account/ModelWhitelistSelector.vue +++ b/frontend/src/components/account/ModelWhitelistSelector.vue @@ -131,7 +131,8 @@ const { t } = useI18n() const props = defineProps<{ modelValue: string[] - platform: string + platform?: string + platforms?: string[] }>() const emit = defineEmits<{ @@ -144,11 +145,36 @@ const showDropdown = ref(false) const searchQuery = ref('') const customModel = ref('') const isComposing = ref(false) +const normalizedPlatforms = computed(() => { + const rawPlatforms = + props.platforms && props.platforms.length > 0 + ? props.platforms + : props.platform + ? 
[props.platform] + : [] + + return Array.from( + new Set( + rawPlatforms + .map(platform => platform?.trim()) + .filter((platform): platform is string => Boolean(platform)) + ) + ) +}) + const availableOptions = computed(() => { - if (props.platform === 'sora') { - return getModelsByPlatform('sora').map(m => ({ value: m, label: m })) + if (normalizedPlatforms.value.length === 0) { + return allModels } - return allModels + + const allowedModels = new Set() + for (const platform of normalizedPlatforms.value) { + for (const model of getModelsByPlatform(platform)) { + allowedModels.add(model) + } + } + + return allModels.filter(model => allowedModels.has(model.value)) }) const filteredModels = computed(() => { @@ -192,10 +218,13 @@ const handleEnter = () => { } const fillRelated = () => { - const models = getModelsByPlatform(props.platform) const newModels = [...props.modelValue] - for (const model of models) { - if (!newModels.includes(model)) newModels.push(model) + for (const platform of normalizedPlatforms.value) { + for (const model of getModelsByPlatform(platform)) { + if (!newModels.includes(model)) { + newModels.push(model) + } + } } emit('update:modelValue', newModels) } diff --git a/frontend/src/components/account/QuotaBadge.vue b/frontend/src/components/account/QuotaBadge.vue new file mode 100644 index 00000000..7cf0f59d --- /dev/null +++ b/frontend/src/components/account/QuotaBadge.vue @@ -0,0 +1,49 @@ + + + diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue new file mode 100644 index 00000000..fdc19ad9 --- /dev/null +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -0,0 +1,295 @@ + + + diff --git a/frontend/src/components/account/TempUnschedStatusModal.vue b/frontend/src/components/account/TempUnschedStatusModal.vue index b2c0b71b..a3e64c48 100644 --- a/frontend/src/components/account/TempUnschedStatusModal.vue +++ b/frontend/src/components/account/TempUnschedStatusModal.vue @@ 
-29,6 +29,10 @@
+
+ {{ t('admin.accounts.recoverStateHint') }} +
+

{{ t('admin.accounts.tempUnschedulable.accountName') }} @@ -131,7 +135,7 @@ d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z" > - {{ t('admin.accounts.tempUnschedulable.reset') }} + {{ t('admin.accounts.recoverState') }}

@@ -154,7 +158,7 @@ const props = defineProps<{ const emit = defineEmits<{ close: [] - reset: [] + reset: [account: Account] }>() const { t } = useI18n() @@ -225,12 +229,12 @@ const handleReset = async () => { if (!props.account) return resetting.value = true try { - await adminAPI.accounts.resetTempUnschedulable(props.account.id) - appStore.showSuccess(t('admin.accounts.tempUnschedulable.resetSuccess')) - emit('reset') + const updated = await adminAPI.accounts.recoverState(props.account.id) + appStore.showSuccess(t('admin.accounts.recoverStateSuccess')) + emit('reset', updated) handleClose() } catch (error: any) { - appStore.showError(error?.message || t('admin.accounts.tempUnschedulable.resetFailed')) + appStore.showError(error?.message || t('admin.accounts.recoverStateFailed')) } finally { resetting.value = false } diff --git a/frontend/src/components/account/UsageProgressBar.vue b/frontend/src/components/account/UsageProgressBar.vue index 93844295..cd5c991f 100644 --- a/frontend/src/components/account/UsageProgressBar.vue +++ b/frontend/src/components/account/UsageProgressBar.vue @@ -1,21 +1,20 @@
@@ -51,7 +55,7 @@ import { Icon } from '@/components/icons' import type { Account } from '@/types' const props = defineProps<{ show: boolean; account: Account | null; position: { top: number; left: number } | null }>() -const emit = defineEmits(['close', 'test', 'stats', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit']) +const emit = defineEmits(['close', 'test', 'stats', 'schedule', 'reauth', 'refresh-token', 'recover-state', 'reset-quota']) const { t } = useI18n() const isRateLimited = computed(() => { if (props.account?.rate_limit_reset_at && new Date(props.account.rate_limit_reset_at) > new Date()) { @@ -67,6 +71,17 @@ const isRateLimited = computed(() => { return false }) const isOverloaded = computed(() => props.account?.overload_until && new Date(props.account.overload_until) > new Date()) +const isTempUnschedulable = computed(() => props.account?.temp_unschedulable_until && new Date(props.account.temp_unschedulable_until) > new Date()) +const hasRecoverableState = computed(() => { + return props.account?.status === 'error' || Boolean(isRateLimited.value) || Boolean(isOverloaded.value) || Boolean(isTempUnschedulable.value) +}) +const hasQuotaLimit = computed(() => { + return (props.account?.type === 'apikey' || props.account?.type === 'bedrock') && ( + (props.account?.quota_limit ?? 0) > 0 || + (props.account?.quota_daily_limit ?? 0) > 0 || + (props.account?.quota_weekly_limit ?? 0) > 0 + ) +}) const handleKeydown = (event: KeyboardEvent) => { if (event.key === 'Escape') emit('close') diff --git a/frontend/src/components/admin/account/AccountBulkActionsBar.vue b/frontend/src/components/admin/account/AccountBulkActionsBar.vue index 41111484..3b987bd0 100644 --- a/frontend/src/components/admin/account/AccountBulkActionsBar.vue +++ b/frontend/src/components/admin/account/AccountBulkActionsBar.vue @@ -20,6 +20,8 @@
+ + @@ -29,5 +31,5 @@ \ No newline at end of file +defineProps(['selectedIds']); defineEmits(['delete', 'edit', 'clear', 'select-page', 'toggle-schedulable', 'reset-status', 'refresh-token']); const { t } = useI18n() + diff --git a/frontend/src/components/admin/account/AccountStatsModal.vue b/frontend/src/components/admin/account/AccountStatsModal.vue index 72a71d36..4dc84d5e 100644 --- a/frontend/src/components/admin/account/AccountStatsModal.vue +++ b/frontend/src/components/admin/account/AccountStatsModal.vue @@ -410,6 +410,18 @@ + + + + @@ -453,6 +465,7 @@ import { Line } from 'vue-chartjs' import BaseDialog from '@/components/common/BaseDialog.vue' import LoadingSpinner from '@/components/common/LoadingSpinner.vue' import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue' +import EndpointDistributionChart from '@/components/charts/EndpointDistributionChart.vue' import Icon from '@/components/icons/Icon.vue' import { adminAPI } from '@/api/admin' import type { Account, AccountUsageStatsResponse } from '@/types' diff --git a/frontend/src/components/admin/account/AccountTableFilters.vue b/frontend/src/components/admin/account/AccountTableFilters.vue index 5280e787..d8068336 100644 --- a/frontend/src/components/admin/account/AccountTableFilters.vue +++ b/frontend/src/components/admin/account/AccountTableFilters.vue @@ -24,7 +24,7 @@ const updateType = (value: string | number | boolean | null) => { emit('update:f const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) } const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) } const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 
'Sora' }]) -const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }]) -const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }]) +const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }]) +const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }]) const gOpts = computed(() => [{ value: '', label: t('admin.accounts.allGroups') }, ...(props.groups || []).map(g => ({ value: String(g.id), label: g.name }))]) diff --git a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue index a25c25cc..e731a7b1 100644 --- a/frontend/src/components/admin/account/AccountTestModal.vue +++ b/frontend/src/components/admin/account/AccountTestModal.vue @@ -61,6 +61,17 @@ {{ t('admin.accounts.soraTestHint') }}
+
+